You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.cc 91 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_scheduler.h"
  17. #include "runtime/framework/actor/memory_manager_actor.h"
  18. #include "runtime/framework/actor/debug_actor.h"
  19. #include "runtime/framework/actor/recorder_actor.h"
  20. #include "runtime/hardware/device_context_manager.h"
  21. #include "mindrt/src/actor/actormgr.h"
  22. #include "mindrt/include/async/async.h"
  23. #include "backend/session/anf_runtime_algorithm.h"
  24. #include "backend/optimizer/common/helper.h"
  25. #include "utils/config_manager.h"
  26. #include "utils/log_adapter.h"
  27. #include "utils/convert_utils.h"
  28. #include "utils/ms_context.h"
  29. #include "utils/profile.h"
  30. #if !defined(_WIN32) && !defined(_WIN64)
  31. #include "utils/signal_util.h"
  32. #endif
  33. #ifndef ENABLE_SECURITY
  34. #include "debug/data_dump/dump_json_parser.h"
  35. #endif
  36. #ifdef ENABLE_DUMP_IR
  37. #include "debug/rdr/recorder_manager.h"
  38. #include "debug/rdr/running_data_recorder.h"
  39. #endif
  40. #ifdef ENABLE_DEBUGGER
  41. #include "debug/debugger/debugger.h"
  42. #endif
  43. #include "profiler/device/profiling.h"
  44. #include "debug/common.h"
  45. namespace mindspore {
  46. namespace runtime {
  47. namespace {
  48. bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
  49. MS_EXCEPTION_IF_NULL(from_device_context);
  50. MS_EXCEPTION_IF_NULL(to_device_context);
  51. if (from_device_context->GetDeviceAddressType() == to_device_context->GetDeviceAddressType()) {
  52. return false;
  53. } else {
  54. return true;
  55. }
  56. }
  57. inline bool IsSingleOpActorSet(const ActorSet *actor_set) {
  58. MS_EXCEPTION_IF_NULL(actor_set);
  59. return actor_set->kernel_actors_.size() == 1;
  60. }
  61. // Convert the actors vector by the actor set.
  62. std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
  63. MS_EXCEPTION_IF_NULL(actor_set);
  64. std::vector<AbstractActorPtr> actors;
  65. if (actor_set->data_prepare_actor_ != nullptr) {
  66. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->data_prepare_actor_));
  67. }
  68. for (auto &data_source_actor : actor_set->data_source_actors_) {
  69. MS_EXCEPTION_IF_NULL(data_source_actor);
  70. (void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
  71. }
  72. for (auto &kernel_actor : actor_set->kernel_actors_) {
  73. MS_EXCEPTION_IF_NULL(kernel_actor);
  74. (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
  75. }
  76. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  77. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  78. (void)actors.emplace_back(static_cast<AbstractActorPtr>(super_kernel_actor));
  79. }
  80. for (auto &copy_actor : actor_set->copy_actors_) {
  81. MS_EXCEPTION_IF_NULL(copy_actor);
  82. (void)actors.emplace_back(static_cast<AbstractActorPtr>(copy_actor));
  83. }
  84. if (actor_set->loop_count_actor_ != nullptr) {
  85. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->loop_count_actor_));
  86. }
  87. if (actor_set->output_actor_ != nullptr) {
  88. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
  89. }
  90. if (actor_set->control_actors_ != nullptr) {
  91. const auto &control_actor_set = actor_set->control_actors_;
  92. for (auto &switch_actor : control_actor_set->switch_actors_) {
  93. MS_EXCEPTION_IF_NULL(switch_actor);
  94. (void)actors.emplace_back(static_cast<AbstractActorPtr>(switch_actor));
  95. }
  96. for (auto &gather_actor : control_actor_set->gather_actors_) {
  97. MS_EXCEPTION_IF_NULL(gather_actor);
  98. (void)actors.emplace_back(static_cast<AbstractActorPtr>(gather_actor));
  99. }
  100. for (auto &entrance_actor : control_actor_set->entrance_actors_) {
  101. MS_EXCEPTION_IF_NULL(entrance_actor);
  102. (void)actors.emplace_back(static_cast<AbstractActorPtr>(entrance_actor));
  103. }
  104. for (auto &exit_actor : control_actor_set->exit_actors_) {
  105. MS_EXCEPTION_IF_NULL(exit_actor);
  106. (void)actors.emplace_back(static_cast<AbstractActorPtr>(exit_actor));
  107. }
  108. for (auto &stack_actor : control_actor_set->stack_actors_) {
  109. MS_EXCEPTION_IF_NULL(stack_actor);
  110. (void)actors.emplace_back(static_cast<AbstractActorPtr>(stack_actor));
  111. }
  112. }
  113. return actors;
  114. }
// Release all device tensors and device-tensor-store entries owned by the given graph:
// input parameters (only once no other graph still uses them), input value nodes, and
// the outputs of every executed cnode. Called when a graph's actors are being cleared.
void ClearNodeInfo(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Clear input parameter device tensor and device tensor store.
  for (const auto &input_node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(input_node);
    if (!input_node->isa<Parameter>()) {
      continue;
    }
    auto parameter = input_node->cast<ParameterPtr>();
    MS_EXCEPTION_IF_NULL(parameter);
    // Parameters can be shared across graphs; drop this graph's reference first.
    parameter->DecreaseUsedGraphCount();
    // Only the parameter has no graph used, then clear the device tensor.
    if (parameter->used_graph_count() != 0) {
      continue;
    }
    auto front_input_node = FetchFrontNodeByBackendNode(input_node, graph);
    DeviceTensorStore::GetInstance().Remove(front_input_node.get());
    size_t output_num = AnfAlgo::GetOutputTensorNum(input_node);
    for (size_t index = 0; index < output_num; ++index) {
      if (AnfAlgo::OutputAddrExist(input_node, index)) {
        // Setting a null address releases the device memory binding for this output.
        AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
      }
    }
  }
  // Clear input value node device tensor and device tensor store.
  for (const auto &value_node : graph->graph_value_nodes()) {
    auto front_value_node = FetchFrontNodeByBackendNode(value_node, graph);
    DeviceTensorStore::GetInstance().Remove(front_value_node.get());
    // Value nodes have a single output at index 0.
    if (AnfAlgo::OutputAddrExist(value_node, 0)) {
      AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
    }
  }
  // Clear cnode device tensor.
  for (const auto &cnode : graph->execution_order()) {
    size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
    for (size_t index = 0; index < output_num; ++index) {
      if (AnfAlgo::OutputAddrExist(cnode, index)) {
        AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
      }
    }
  }
}
#if !defined(_WIN32) && !defined(_WIN64)
// SIGINT (KeyboardInterrupt) handler: log the interrupt and re-signal this process
// with SIGTERM so the regular termination path performs shutdown and cleanup.
void IntHandler(int, siginfo_t *, void *) {
  int this_pid = getpid();
  MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
  (void)kill(this_pid, SIGTERM);
}
#endif
  164. } // namespace
// Tear down a single actor set: terminate and unregister its actors, release the
// device tensors of the given graphs, and drop the set from the global actors map.
// NOTE(review): MS_EXCEPTION_IF_NULL can throw inside this noexcept function, which
// would call std::terminate — confirm that is the intended failure mode.
void GraphScheduler::Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept {
  // Terminate the actors of actor info.
  if (actors_.count(actor_info) > 0) {
    auto actor_manager = ActorMgr::GetActorMgrRef();
    if (actor_manager == nullptr) {
      MS_LOG(ERROR) << "Actor manager is not exist.";
      return;
    }
    auto actor_set = actors_[actor_info];
    auto base_actors = CollectActors(actor_set.get());
    for (auto &base_actor : base_actors) {
      MS_EXCEPTION_IF_NULL(base_actor);
      // Remove from the global actor registry before asking the manager to stop it.
      EraseActor(base_actor->GetAID().Name());
      actor_manager->Terminate(base_actor->GetAID());
    }
  }
  // Clear device tensor and device tensor store.
  for (auto &graph : graphs) {
    ClearNodeInfo(graph);
  }
  // Clear global maps of actor info.
  (void)actors_.erase(actor_info);
}
// Full scheduler shutdown: finalize the actor runtime, empty the device tensor store,
// and reset every global actor map. Used when the whole runtime is being destroyed.
void GraphScheduler::Clear() {
  // Terminate all actors.
  auto actor_manager = ActorMgr::GetActorMgrRef();
  MS_EXCEPTION_IF_NULL(actor_manager);
  actor_manager->Finalize();
  // Clear the member of DeviceTensorStore.
  DeviceTensorStore::GetInstance().Clear();
  // Clear global maps.
  actors_.clear();
  ClearAllActors();
}
// Pointer-to-member type for the data-arrow linking routines, dispatched per source
// actor kind (args: from-actor, to-actor, from node/index, to node/index, graph).
using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &,
                                                   const KernelWithIndex &, const KernelGraphPtr &);
// Maps each KernelTransformType to its linking routine; populated once in Initialize().
static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc;
// One-time scheduler initialization: register the per-kernel-type link functions,
// start the actor thread pool (also setting OMP_NUM_THREADS), and spawn the global
// actors (memory manager, recorder, debug). Idempotent via the init_ flag.
void GraphScheduler::Initialize() {
  if (init_) {
    return;
  }
  init_ = true;
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceDataSourceActor,
                                      &GraphScheduler::LinkDataArrowForBaseActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kHostDataSourceActor,
                                      &GraphScheduler::LinkDataArrowForHostDSActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kKernelActor, &GraphScheduler::LinkDataArrowForKernelActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kSuperKernelActor,
                                      &GraphScheduler::LinkDataArrowForBaseActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceTensorStore,
                                      &GraphScheduler::LinkDataArrowForDeviceTensorStore);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
                                      &GraphScheduler::LinkDataArrowForInternalParameter);
  // Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
  size_t actor_thread_num = 0;
  size_t actor_and_kernel_thread_num = 0;
  ComputeThreadNums(&actor_thread_num, &actor_and_kernel_thread_num);
  auto actor_manager = ActorMgr::GetActorMgrRef();
  MS_EXCEPTION_IF_NULL(actor_manager);
  auto ret = actor_manager->Initialize(true, actor_thread_num, actor_and_kernel_thread_num);
  if (ret != MINDRT_OK) {
    MS_LOG(EXCEPTION) << "Actor manager init failed.";
  }
  (void)common::SetOMPThreadNum();
  auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS");
  MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
               << ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num)
               << ", the used OMP thread number: " << OMP_thread_num_used;
  BuildAndScheduleGlobalActor();
}
  235. void GraphScheduler::BuildAndScheduleGlobalActor() {
  236. auto actor_manager = ActorMgr::GetActorMgrRef();
  237. MS_EXCEPTION_IF_NULL(actor_manager);
  238. // Create and schedule memory manager actor.
  239. auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
  240. MS_EXCEPTION_IF_NULL(memory_manager_actor);
  241. memory_manager_aid_ = memory_manager_actor->GetAID();
  242. auto base_actor = static_cast<ActorReference>(memory_manager_actor);
  243. // Bind single thread to response to memory alloc and free quickly.
  244. (void)actor_manager->Spawn(base_actor, false);
  245. // Create and schedule recorder actor.
  246. auto recorder_actor = std::make_shared<RecorderActor>();
  247. MS_EXCEPTION_IF_NULL(recorder_actor);
  248. recorder_aid_ = &(recorder_actor->GetAID());
  249. auto base_recorder_actor = static_cast<ActorReference>(recorder_actor);
  250. (void)actor_manager->Spawn(base_recorder_actor, true);
  251. // Create and schedule debug actor.
  252. #ifndef ENABLE_SECURITY
  253. bool debugger_actor_need = DumpJsonParser::GetInstance().e2e_dump_enabled();
  254. #endif
  255. #ifdef ENABLE_DEBUGGER
  256. if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
  257. debugger_actor_need = true;
  258. }
  259. #endif
  260. #ifndef ENABLE_SECURITY
  261. if (debugger_actor_need) {
  262. auto debug_actor = std::make_shared<DebugActor>();
  263. MS_EXCEPTION_IF_NULL(debug_actor);
  264. debug_aid_ = &(debug_actor->GetAID());
  265. auto base_debug_actor = static_cast<ActorReference>(debug_actor);
  266. (void)actor_manager->Spawn(base_debug_actor, true);
  267. }
  268. #endif
  269. }
// Transform compiled graphs into a runnable actor set: persist device tensors, build
// all actors, cache graph outputs, link the arrows between actors, and (in pipeline
// mode) validate the result. Returns a raw pointer owned by the actors_ map.
ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info) {
  MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
  if (graph_compiler_info.graphs_.size() == 0) {
    MS_LOG(EXCEPTION) << "The number of graphs is zero.";
  }
  // Graphs and device contexts are parallel arrays; they must pair up one-to-one.
  if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
    MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
  }
  PersistDeviceTensor(graph_compiler_info);
  const auto &actor_set = Build(graph_compiler_info);
  MS_EXCEPTION_IF_NULL(actor_set);
  CacheGraphOutputToActor(graph_compiler_info);
  Link(actor_set.get(), graph_compiler_info);
  // The copy actors are built in the link, so need push into the actor set after link.
  actor_set->copy_actors_ = copy_actors_;
  DumpActor(actor_set.get(), graph_compiler_info);
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
    CheckActorValid(actor_set.get());
  }
  MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
  // Local maps and vectors clear.
  graph_output_to_actor_.clear();
  copy_actors_.clear();
  // Safe to return the raw pointer: Build() stored the shared_ptr in actors_.
  return actor_set.get();
}
  295. void GraphScheduler::Schedule(const ActorSet *actor_set) {
  296. MS_EXCEPTION_IF_NULL(actor_set);
  297. auto actors = CollectActors(actor_set);
  298. // Schedule actors.
  299. auto actor_manager = ActorMgr::GetActorMgrRef();
  300. MS_EXCEPTION_IF_NULL(actor_manager);
  301. for (auto actor : actors) {
  302. (void)actor_manager->Spawn(actor);
  303. }
  304. }
// Execute one iteration of the actor set. In step mode with a single kernel actor the
// run is performed synchronously on this thread; otherwise the data prepare actor is
// triggered asynchronously and this thread blocks on the result future. On success in
// pipeline mode, every involved device stream is synchronized once. Finally the
// measured wall time feeds the single/multi-thread execution strategy decision.
void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceContext *> &device_contexts,
                         const std::vector<std::vector<TensorPtr>> &input_tensors,
                         const std::vector<TensorPtr> &input_tensors_with_value_node, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
#if !defined(_WIN32) && !defined(_WIN64)
  // Install the KeyboardInterrupt handler for the duration of the run (RAII guard).
  SignalGuard sg(IntHandler);
#endif
  // Construct OpContext.
  OpContext<DeviceTensor> op_context;
  std::vector<Promise<int>> result(1);
  op_context.sequential_num_ = RandInt::Instance().Get();
  op_context.results_ = &result;
  // Fast path: single-op step mode runs synchronously without the actor pipeline.
  if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
    actor_set->data_prepare_actor_->PrepareData(input_tensors, &op_context);
    MS_EXCEPTION_IF_NULL(actor_set->kernel_actors_[0]);
    actor_set->kernel_actors_[0]->RunOpControlWithInputTensor(nullptr, &op_context, &input_tensors_with_value_node);
    return;
  }
  // Trigger data prepare actor running.
  MS_EXCEPTION_IF_NULL(ActorMgr::GetActorMgrRef());
  auto thread_pool = ActorMgr::GetActorMgrRef()->GetActorThreadPool();
  MS_EXCEPTION_IF_NULL(thread_pool);
  if (actor_set->is_multi_thread_execution_) {
    // Spin harder while actors are actively exchanging messages across threads.
    thread_pool->SetSpinCountMaxValue();
  }
  ActorDispatcher::is_multi_thread_execution(actor_set->is_multi_thread_execution_);
  double start_time = GetTime();
  ActorDispatcher::Send(actor_set->data_prepare_actor_->GetAID(), &DataPrepareActor::PrepareData, input_tensors,
                        &op_context);
  // Get the run result.
  auto result_future = result[0].GetFuture();
  result_future.Wait();
  MsException::Instance().CheckException();
  thread_pool->SetSpinCountMinValue();
  if (!result_future.IsOK()) {
#ifdef ENABLE_DUMP_IR
    mindspore::RDR::TriggerAll();
#endif
    // When temporary variable 'op_context' has beed set failed status, the main thread need wait other threads until
    // they finish respective task, otherwise segmentation fault will happen when these task access 'op_context',
    // because it has been destroyed.
    std::mutex mutex;
    std::unique_lock<std::mutex> locker(mutex);
    std::condition_variable thread_blocker;
    const int64_t kTimeToWait = 2;
    // No notifier exists for this condvar; the wait simply times out after kTimeToWait seconds.
    thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait));
    MS_LOG(EXCEPTION) << op_context.error_info_;
  }
  // Sync device stream.
  if (strategy == GraphExecutionStrategy::kPipeline) {
    // Each distinct device context is synchronized at most once.
    std::set<DeviceContext *> sync_stream_device_contexts;
    for (auto &device_context : device_contexts) {
      MS_EXCEPTION_IF_NULL(device_context);
      if ((sync_stream_device_contexts.count(device_context) == 0) && (!device_context->SyncStream())) {
        MS_LOG(EXCEPTION) << "Sync stream failed:" << device_context->device_context_key().ToString();
      }
      (void)sync_stream_device_contexts.insert(device_context);
    }
  }
  double end_time = GetTime();
  const size_t kSecondsToMilliseconds = 1000;
  SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
}
  369. void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
  370. double execution_time) {
  371. MS_EXCEPTION_IF_NULL(actor_set);
  372. MS_EXCEPTION_IF_NULL(actor_set->loop_count_actor_);
  373. ++actor_set->execution_count_;
  374. MS_LOG(DEBUG) << "Execution count: " << actor_set->execution_count_ << ", execution time cost: " << execution_time
  375. << " ms in multi thread or not: " << actor_set->is_multi_thread_execution_ << ".";
  376. #if defined(_WIN32) || defined(_WIN64)
  377. return;
  378. #endif
  379. // The step mode uses the default multi thread.
  380. if (strategy == GraphExecutionStrategy::kStep) {
  381. return;
  382. }
  383. if ((actor_set->copy_actors_.size() > 0) || (actor_set->super_kernel_actors_.size() > 0) ||
  384. (actor_set->kernel_actors_.size() > ActorDispatcher::kSingleThreadExecutionActorMaxNum)) {
  385. return;
  386. }
  387. if ((actor_set->is_multi_thread_execution_) &&
  388. (actor_set->execution_count_ >= ActorDispatcher::kMultiThreadExecutionCountBegin) &&
  389. (actor_set->execution_count_ <= ActorDispatcher::kMultiThreadExecutionCountEnd)) {
  390. actor_set->multi_thread_execution_time_ += execution_time;
  391. if (actor_set->execution_count_ == ActorDispatcher::kMultiThreadExecutionCountEnd) {
  392. actor_set->multi_thread_execution_time_ /=
  393. (ActorDispatcher::kMultiThreadExecutionCountEnd - ActorDispatcher::kMultiThreadExecutionCountBegin + 1);
  394. actor_set->multi_thread_execution_time_ /= actor_set->loop_count_actor_->loop_count();
  395. actor_set->is_multi_thread_execution_ = false;
  396. }
  397. return;
  398. }
  399. if ((!actor_set->is_multi_thread_execution_) &&
  400. (actor_set->execution_count_ >= ActorDispatcher::kSingleThreadExecutionCountBegin) &&
  401. (actor_set->execution_count_ <= ActorDispatcher::kSingleThreadExecutionCountEnd)) {
  402. actor_set->single_thread_execution_time_ += execution_time;
  403. if (actor_set->execution_count_ == ActorDispatcher::kSingleThreadExecutionCountEnd) {
  404. actor_set->single_thread_execution_time_ /=
  405. (ActorDispatcher::kSingleThreadExecutionCountEnd - ActorDispatcher::kSingleThreadExecutionCountBegin + 1);
  406. actor_set->single_thread_execution_time_ /= actor_set->loop_count_actor_->loop_count();
  407. actor_set->is_multi_thread_execution_ =
  408. (actor_set->multi_thread_execution_time_ <= actor_set->single_thread_execution_time_) ? true : false;
  409. MS_LOG(INFO) << "Multi thread execution time cost: " << actor_set->multi_thread_execution_time_
  410. << " ms, single thread execution time cost: " << actor_set->single_thread_execution_time_
  411. << " ms, decide to use multi thread execution or not: " << actor_set->is_multi_thread_execution_
  412. << ".";
  413. }
  414. return;
  415. }
  416. }
  417. ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
  418. auto iter = actors_.find(actor_info);
  419. if (iter != actors_.end()) {
  420. return iter->second.get();
  421. } else {
  422. MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
  423. return nullptr;
  424. }
  425. }
// Construct every actor of a new actor set for the compiled graphs and register the
// set in the global actors_ map. Note the build order: data source actors must exist
// before the data prepare actor, which takes them as input.
ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
  auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
  MS_EXCEPTION_IF_NULL(actor_set);
  (void)actors_.emplace(actor_set->name_, actor_set);
  // The host queue carries host tensors from the frontend into the data source actor.
  auto host_queue = std::make_shared<HostTensorQueue>();
  actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
  actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
  actor_set->super_kernel_actors_ = BuildSuperKernelActor(graph_compiler_info);
  actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
  actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
  actor_set->data_prepare_actor_ =
    BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
  actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info);
  return actor_set;
}
// For every graph output, record the mapping from its front (original) node to the
// actor (and backend node/index) that produces it, so output arrows can later be
// linked by front node. Skipped entirely in step mode.
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
    return;
  }
  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
    for (const auto &output_with_index : outputs) {
      auto output_kernel = output_with_index.first;
      MS_EXCEPTION_IF_NULL(output_kernel);
      auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
      // Outputs without a front node cannot be cached; warn and move on.
      if (origin_output_with_index.first == nullptr) {
        MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                        << " with index: " << output_with_index.second << " has no front node.";
        continue;
      }
      auto kernel_type = KernelTransformType::kUnknown;
      std::string kernel_name = "";
      FetchKernelTransformTypeAndName(output_kernel, graph, graph_compiler_info, &kernel_type, &kernel_name);
      // An empty kernel name means the output is not produced by an actor (e.g. a
      // persisted device tensor); it is still cached below with a null actor.
      if (kernel_name == "") {
        MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                     << " with index:" << output_with_index.second
                     << " is not actor, and the kernel type is:" << kernel_type;
      }
      auto output_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
      auto output_actor_name = (output_actor != nullptr) ? output_actor->GetAID().Name() : "";
      (void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(output_actor, output_with_index));
      MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                   << " with index:" << output_with_index.second << " to actor:" << output_actor_name
                   << ", from front node:" << origin_output_with_index.first->fullname_with_scope()
                   << " with index:" << origin_output_with_index.second;
    }
  }
}
// Wire up all arrows of the actor set: per-graph data arrows (sink vs non-sink mode),
// the global control arrows, the output-result arrows, and — when control flow is
// involved in pipeline mode — the control-node scheduler's arrows.
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  // Collected while linking the non-sink graphs; consumed by the global control link.
  std::vector<KernelActor *> auto_monad_actors;
  std::vector<CNodePtr> communication_nodes;
  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    if (graph->is_executing_sink()) {
      LinkDataArrowInSinkMode(graph, graph_compiler_info);
    } else {
      LinkDataArrowInNonSinkMode(graph, graph_compiler_info, &auto_monad_actors, &communication_nodes);
    }
  }
  LinkGlobalControlArrow(actor_set, communication_nodes, auto_monad_actors, graph_compiler_info);
  LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
  // Link the arrow in the control flow scene.
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline &&
      graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited()) {
    control_node_scheduler_.Link(actor_set, graph_compiler_info);
  }
}
// Build the data source actors of the actor set:
// 1. A single host queue data source actor shared by all graphs (created lazily), which feeds the host-side
//    input parameters through the host tensor queue.
// 2. One device queue data source actor per non-sink graph that contains a device-queue data kernel.
// Also registers the backend parameters of the control nodes into the host data source actor so that the
// control flow can locate real parameters by front node.
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
                                                                     const HostTensorQueuePtr &host_queue) {
  std::vector<DataSourceActorPtr> data_source_actors;
  HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  // The next free slot index in the host queue data source actor.
  size_t data_node_position = 0;
  // Records the slot already assigned to a front node, so that multiple backend nodes corresponding to the same
  // front node share one slot.
  std::unordered_map<AnfNodePtr, size_t> front_node_position_temp_map;

  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);

    // Build host queue data source actor.
    const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
    for (size_t j = 0; j < input_nodes.size(); j++) {
      const auto &input_node = input_nodes[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (IsHostQueueDSActor(input_node, graph, graph_compiler_info.origin_parameters_order_,
                             graph_compiler_info.strategy_)) {
        // Create the shared host queue data source actor lazily, on the first qualified input node.
        if (host_queue_ds_actor == nullptr) {
          auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
          MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
          host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr,
                                                                           nullptr, host_queue);
          InsertActor(host_queue_ds_actor.get());
          (void)data_source_actors.emplace_back(host_queue_ds_actor);
        }

        const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
        // In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
        // is saved in the host queue data source actor.
        if (front_node_position_temp_map.count(front_node) > 0) {
          (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node,
                                                                     front_node_position_temp_map[front_node]);
          continue;
        }
        // Append a new slot: record the backend node, its device context, and the position of both the backend
        // node and the front node (they map to the same slot).
        (void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
        (void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
        (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
        // In control flow, need to rely on the front node to find the location of the corresponding real parameter.
        (void)host_queue_ds_actor->data_node_position_map_.emplace(front_node, data_node_position);
        (void)front_node_position_temp_map.emplace(front_node, data_node_position);
        data_node_position++;
      }
    }

    // The graph sink mode has no device queue data source actor.
    if (!graph->is_executing_sink()) {
      // Build device queue data source actor, keyed by the first device-queue data kernel found in the
      // execution order of the graph.
      const auto &execution_order = graph->execution_order();
      const auto &iter =
        std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
          return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
        });
      if (iter != execution_order.end()) {
        auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
        MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
        auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
          actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
        MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
        InsertActor(device_queue_ds_actor.get());
        (void)data_source_actors.emplace_back(device_queue_ds_actor);
        device_queue_ds_actor->data_kernel_ = *iter;
        device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
      }
    }
  }

  MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;

  // Initialize the parameter in the control node, first get all the front parameters in the control node, then find
  // the corresponding backend parameter from the map, and insert it into the host data source actor
  const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
  for (const auto &parameter : control_node_parameters) {
    // Persistent device tensors are not fed through the host queue; skip them.
    if (IsPersistentDeviceTensor(parameter)) {
      continue;
    }
    auto backend_iter = front_to_backend_parameter.find({parameter, 0});
    if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) {
      MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
    }

    // The host queue data source actor may not exist yet when no graph input qualified above.
    if (host_queue_ds_actor == nullptr) {
      auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
      MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
      host_queue_ds_actor =
        std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr, nullptr, host_queue);
      InsertActor(host_queue_ds_actor.get());
      (void)data_source_actors.emplace_back(host_queue_ds_actor);
    }

    // Skip the parameter whose position is already recorded.
    if (host_queue_ds_actor->data_node_position_map_.find(parameter) !=
        host_queue_ds_actor->data_node_position_map_.end()) {
      continue;
    }

    // Reuse the slot of the backend node when it is already saved in the actor; otherwise append a new slot for
    // both the front parameter and its first backend node.
    const auto &backend_node = backend_iter->second.begin()->first;
    auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
    if (iter != host_queue_ds_actor->data_nodes_.end()) {
      (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter,
                                                                 iter - host_queue_ds_actor->data_nodes_.begin());
    } else {
      (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
      (void)host_queue_ds_actor->data_node_position_map_.emplace(backend_iter->second.begin()->first,
                                                                 host_queue_ds_actor->data_nodes_.size());
      (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first);
      (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second);
    }
  }

  return data_source_actors;
}
  598. std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  599. std::vector<KernelActorPtr> kernel_actors;
  600. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  601. const auto &graph = graph_compiler_info.graphs_[i];
  602. const auto &device_context = graph_compiler_info.device_contexts_[i];
  603. MS_EXCEPTION_IF_NULL(graph);
  604. if (graph->is_executing_sink()) {
  605. continue;
  606. }
  607. auto execution_order = graph->execution_order();
  608. // Single op graph in step mode, kernel actor executes synchronously.
  609. bool is_single_op_graph = execution_order.size() == 1;
  610. GraphExecutionStrategy strategy = graph_compiler_info.strategy_;
  611. if (strategy == GraphExecutionStrategy::kStep) {
  612. strategy = (is_single_op_graph ? strategy : GraphExecutionStrategy::kPipeline);
  613. }
  614. for (auto &kernel : execution_order) {
  615. MS_EXCEPTION_IF_NULL(kernel);
  616. if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
  617. auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context,
  618. memory_manager_aid_, debug_aid_, recorder_aid_, strategy);
  619. MS_EXCEPTION_IF_NULL(kernel_actor);
  620. InsertActor(kernel_actor.get());
  621. (void)kernel_actors.emplace_back(kernel_actor);
  622. }
  623. }
  624. }
  625. return kernel_actors;
  626. }
  627. std::vector<SuperKernelActorPtr> GraphScheduler::BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  628. std::vector<SuperKernelActorPtr> super_kernel_actors;
  629. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  630. const auto &graph = graph_compiler_info.graphs_[i];
  631. const auto &device_context = graph_compiler_info.device_contexts_[i];
  632. MS_EXCEPTION_IF_NULL(graph);
  633. if (!graph->is_executing_sink()) {
  634. continue;
  635. }
  636. auto actor_name = graph->ToString() + "_SuperKernelActor";
  637. auto super_kernel_actor =
  638. std::make_shared<SuperKernelActor>(actor_name, graph, device_context, memory_manager_aid_, nullptr, nullptr);
  639. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  640. InsertActor(super_kernel_actor.get());
  641. (void)super_kernel_actors.emplace_back(super_kernel_actor);
  642. }
  643. return super_kernel_actors;
  644. }
  645. LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
  646. auto actor_set = Fetch(graph_compiler_info.name_);
  647. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  648. return nullptr;
  649. }
  650. auto loop_count = ConfigManager::GetInstance().iter_num();
  651. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) ||
  652. (graph_compiler_info.graphs_.size() == 1 && graph_compiler_info.graphs_[0]->is_loop_count_sink())) {
  653. loop_count = 1;
  654. }
  655. auto actor_name = graph_compiler_info.name_ + "_LoopCountActor";
  656. auto loop_count_actor =
  657. std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_);
  658. MS_LOG(INFO) << "Create loop count actor: " << actor_name;
  659. MS_EXCEPTION_IF_NULL(loop_count_actor);
  660. InsertActor(loop_count_actor.get());
  661. return loop_count_actor;
  662. }
  663. OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_compiler_info) {
  664. auto actor_set = Fetch(graph_compiler_info.name_);
  665. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  666. return nullptr;
  667. }
  668. auto loop_count = ConfigManager::GetInstance().iter_num();
  669. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) ||
  670. (graph_compiler_info.graphs_.size() == 1 && graph_compiler_info.graphs_[0]->is_loop_count_sink())) {
  671. loop_count = 1;
  672. }
  673. auto actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
  674. auto output_actor = std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.outputs_num_);
  675. MS_LOG(INFO) << "Create output actor: " << actor_name;
  676. MS_EXCEPTION_IF_NULL(output_actor);
  677. InsertActor(output_actor.get());
  678. return output_actor;
  679. }
  680. DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
  681. const std::vector<DataSourceActorPtr> &data_source_actors,
  682. const HostTensorQueuePtr &host_queue) {
  683. HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  684. auto iter = std::find_if(data_source_actors.begin(), data_source_actors.end(), [&](const auto &data_source_actor) {
  685. return data_source_actor->type_ == KernelTransformType::kHostDataSourceActor;
  686. });
  687. if (iter != data_source_actors.end()) {
  688. host_queue_ds_actor = std::dynamic_pointer_cast<HostQueueDataSourceActor>(*iter);
  689. }
  690. auto actor_name = graph_compiler_info.name_ + "_DataPrepareActor";
  691. auto data_prepare_actor = std::make_shared<DataPrepareActor>(actor_name, memory_manager_aid_, debug_aid_,
  692. &graph_compiler_info, host_queue_ds_actor, host_queue);
  693. MS_LOG(INFO) << "Create data prepare actor: " << actor_name;
  694. MS_EXCEPTION_IF_NULL(data_prepare_actor);
  695. // Cache the nodes which need continuous memory.
  696. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
  697. for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
  698. const auto &graph = graph_compiler_info.graphs_[index];
  699. MS_EXCEPTION_IF_NULL(graph);
  700. if (graph->is_executing_sink()) {
  701. continue;
  702. }
  703. auto &execution_order = graph->execution_order();
  704. for (auto &kernel : execution_order) {
  705. if (!AnfAlgo::IsCommunicationOp(kernel)) {
  706. continue;
  707. }
  708. auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]);
  709. auto value = std::make_pair(false, false);
  710. if (AnfAlgo::GetInputTensorNum(kernel) > 1) {
  711. value.first = true;
  712. }
  713. if (AnfAlgo::GetOutputTensorNum(kernel) > 1) {
  714. value.second = true;
  715. }
  716. if ((value.first == true) || (value.second == true)) {
  717. data_prepare_actor->continuous_memory_nodes_[key] = value;
  718. }
  719. }
  720. }
  721. }
  722. InsertActor(data_prepare_actor.get());
  723. return data_prepare_actor;
  724. }
// Collect the actors that have neither data inputs nor control inputs; such actors cannot be triggered by other
// actors and must be launched directly when one iteration starts.
std::vector<AbstractActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
                                                                      GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(actor_set);
  std::vector<AbstractActorPtr> no_input_kernel_actors;

  // A super kernel actor with no inputs at all belongs in the no-input list.
  for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
    MS_EXCEPTION_IF_NULL(super_kernel_actor);
    if ((super_kernel_actor->input_datas_num_ == 0) && (super_kernel_actor->input_controls_num_ == 0)) {
      (void)no_input_kernel_actors.emplace_back(super_kernel_actor);
    }
  }

  for (auto &kernel_actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(kernel_actor);
    // Framework will trigger kernel actor running in the step execution strategy.
    // The added control input accounts for the framework trigger, so the actor is not treated as input-free.
    if (strategy == GraphExecutionStrategy::kStep && IsSingleOpActorSet(actor_set)) {
      kernel_actor->input_controls_num_++;
      continue;
    }

    if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
      // Check whether the kernel actor belongs to the root graph.
      // In general, all no input nodes belong to the root funcgraph, and the corresponding gather actor should be
      // empty. In control flow, the control arrow of the no input node in the sub funcgraph should be sent by the
      // gather actor and should not be placed in the no input list.
      MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
      const auto &graph = kernel_actor->kernel_->func_graph();
      if (graph != nullptr) {
        const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
        MS_EXCEPTION_IF_NULL(kernel_graph);
        const auto func_graph = kernel_graph->GetFuncGraph();
        // An actor registered under the funcgraph name indicates the funcgraph is driven by control flow; its
        // no-input kernels are triggered by that actor instead.
        if (func_graph != nullptr && FetchActor(func_graph->ToString()) != nullptr) {
          continue;
        }
      }
      (void)no_input_kernel_actors.emplace_back(kernel_actor);
    }
  }
  return no_input_kernel_actors;
}
// Link the data arrows from the input nodes of an executing-sink graph to its super kernel actor. Each graph input
// becomes one input slot of the super kernel actor, indexed by its position in the graph input list.
void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph,
                                             const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(graph);
  // The super kernel actor was registered under this name when it was built.
  auto to_actor_name = graph->ToString() + "_SuperKernelActor";
  auto to_actor = dynamic_cast<SuperKernelActor *>(FetchActor(to_actor_name));
  MS_EXCEPTION_IF_NULL(to_actor);

  auto &input_nodes = graph->input_nodes();
  for (size_t node_index = 0; node_index < input_nodes.size(); ++node_index) {
    auto &input_node = input_nodes[node_index];
    MS_EXCEPTION_IF_NULL(input_node);
    // NOTE(review): presumably marks output 0 of the input node as max-ref-count so it is not freed within the
    // step — confirm UpdateRefCount semantics.
    UpdateRefCount(input_node, 0, true);

    // Decide which kind of source actor provides this input node.
    auto kernel_type = KernelTransformType::kUnknown;
    std::string kernel_name = "";
    FetchKernelTransformTypeAndName(input_node, graph, graph_compiler_info, &kernel_type, &kernel_name);
    // Device tensor store inputs need no data arrow in sink mode.
    if (kernel_type == KernelTransformType::kDeviceTensorStore) {
      continue;
    }

    auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
    KernelWithIndex from_kernel_with_output_idx = std::make_pair(input_node, 0);
    // The to-input index is the position of the input node among the graph inputs.
    KernelWithIndex to_kernel_with_input_idx = std::make_pair(input_node, node_index);
    if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
      MS_LOG(EXCEPTION) << "Invalid from node:" << input_node->fullname_with_scope() << ", type:" << kernel_type;
    }
    // Dispatch to the link function that matches the source kernel type.
    (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
                                                to_kernel_with_input_idx, graph);
  }
}
// Link the data arrows of every kernel actor of a non-sink graph by traversing its execution order. Auto monad
// inputs produce control arrows instead of data arrows. The communication kernels and the kernel actors with monad
// inputs are collected into the output vectors for the later global control arrow linking.
void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
                                                const GraphCompilerInfo &graph_compiler_info,
                                                std::vector<KernelActor *> *const auto_monad_actors,
                                                std::vector<CNodePtr> *const communication_nodes) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(auto_monad_actors);
  MS_EXCEPTION_IF_NULL(communication_nodes);

  // The primitives whose cnodes carry the auto monad dependency on their inputs.
  const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
    prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
  auto &execution_order = graph->execution_order();
  // Foreach the execution order to link the actors.
  for (const auto &kernel : execution_order) {
    MS_EXCEPTION_IF_NULL(kernel);
    // Communication kernels are collected even when they do not get a kernel actor linked below.
    if (AnfAlgo::IsCommunicationOp(kernel)) {
      (void)communication_nodes->emplace_back(kernel);
    }
    if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
      continue;
    }
    const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
    MS_EXCEPTION_IF_NULL(kernel_actor);

    for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
      auto input_node = AnfAlgo::GetInputNode(kernel, i);
      // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
      if (IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
        LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
      }
      if (HasAbstractMonad(input_node)) {
        (void)auto_monad_actors->emplace_back(kernel_actor);
        continue;  // No data arrow for monad input.
      }

      // Resolve through wrapper nodes to the real producing kernel and output index.
      KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
      KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
      // The gather of linking data arrows of kernel by the different from kernel type.
      LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
    }
  }

  // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  LinkControlArrowBySendRecvNodes(graph);
}
  829. void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  830. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  831. const KernelWithIndex &to_kernel_with_input_idx) {
  832. MS_EXCEPTION_IF_NULL(to_actor);
  833. MS_EXCEPTION_IF_NULL(graph);
  834. auto from_kernel = from_kernel_with_output_idx.first;
  835. MS_EXCEPTION_IF_NULL(from_kernel);
  836. if (graph_compiler_info.control_node_parser_ != nullptr &&
  837. graph_compiler_info.control_node_parser_->IsControlFlowDataArrow(graph, from_kernel)) {
  838. return;
  839. }
  840. auto kernel_type = KernelTransformType::kUnknown;
  841. std::string kernel_name = "";
  842. FetchKernelTransformTypeAndName(from_kernel, graph, graph_compiler_info, &kernel_type, &kernel_name);
  843. auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
  844. if (kKernelTypeToLinkFunc.count(kernel_type) > 0) {
  845. (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
  846. to_kernel_with_input_idx, graph);
  847. }
  848. }
  849. void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, AbstractActor *const to_actor,
  850. const KernelWithIndex &from_kernel_with_output_idx,
  851. const KernelWithIndex &to_kernel_with_input_idx,
  852. const KernelGraphPtr &graph) {
  853. MS_EXCEPTION_IF_NULL(to_actor);
  854. MS_EXCEPTION_IF_NULL(graph);
  855. if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
  856. return;
  857. }
  858. auto from_kernel = from_kernel_with_output_idx.first;
  859. MS_EXCEPTION_IF_NULL(from_kernel);
  860. auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
  861. (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
  862. }
// Link a data arrow whose source is an internal parameter, i.e. a parameter of this graph that stands for the
// output of a front node produced by another graph in the subgraphs connecting scene. The internal parameter is
// resolved to the real producing actor/node (or to the device tensor store), then the link is dispatched by the
// resolved kernel type.
void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, AbstractActor *to_actor,
                                                       const KernelWithIndex &from_kernel_with_output_idx,
                                                       const KernelWithIndex &to_kernel_with_input_idx,
                                                       const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(graph);
  auto internal_parameter = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(internal_parameter);

  // Parameter ---> front node.
  auto front_output_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
  auto front_output_node = front_output_with_index.first;
  MS_EXCEPTION_IF_NULL(front_output_node);
  // Arrows from a switch front node are not linked here.
  if (IsSwitchActor(front_output_node)) {
    return;
  }

  auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  AbstractActor *real_from_actor = nullptr;
  KernelTransformType kernel_type;
  if (IsPersistentDeviceTensor(front_output_node)) {
    // A persistent front output is provided by the device tensor store.
    kernel_type = KernelTransformType::kDeviceTensorStore;
  } else {
    // front node ---> actor. The producing actor/node pair was cached when the graph outputs were processed.
    if (graph_output_to_actor_.count(front_output_with_index) == 0) {
      MS_LOG(EXCEPTION) << "Can't find actor by front node:" << AnfAlgo::GetNodeDebugString(front_output_node)
                        << ", internal parameter:" << AnfAlgo::GetNodeDebugString(internal_parameter);
    }
    auto actor_pair = graph_output_to_actor_[front_output_with_index];
    MS_EXCEPTION_IF_NULL(actor_pair.first);
    MS_EXCEPTION_IF_NULL(actor_pair.second.first);
    MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
                 << ", corresponding front node:" << front_output_node->fullname_with_scope()
                 << " with index:" << front_output_with_index.second
                 << ", from actor:" << actor_pair.first->GetAID().Name()
                 << " node:" << actor_pair.second.first->fullname_with_scope()
                 << " with index:" << actor_pair.second.second << ", to actor:" << to_actor->GetAID().Name()
                 << " with index:" << to_kernel_with_input_idx.second;
    real_from_actor = actor_pair.first;
    real_from_kernel_with_output_idx = actor_pair.second;
    kernel_type = actor_pair.first->type_;
  }

  if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
    MS_LOG(EXCEPTION) << "Invalid internal parameter:" << internal_parameter->DebugString() << ", type:" << kernel_type;
  }
  // Dispatch to the link function that matches the resolved source type.
  (this->*kKernelTypeToLinkFunc[kernel_type])(real_from_actor, to_actor, real_from_kernel_with_output_idx,
                                              to_kernel_with_input_idx, graph);
}
  909. void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  910. const KernelWithIndex &from_kernel_with_output_idx,
  911. const KernelWithIndex &to_kernel_with_input_idx,
  912. const KernelGraphPtr &) {
  913. MS_EXCEPTION_IF_NULL(from_actor);
  914. MS_EXCEPTION_IF_NULL(to_actor);
  915. auto from_kernel = from_kernel_with_output_idx.first;
  916. MS_EXCEPTION_IF_NULL(from_kernel);
  917. auto from_output_index = from_kernel_with_output_idx.second;
  918. auto to_input_index = to_kernel_with_input_idx.second;
  919. // Get the position of from kernel in the data source actor.
  920. auto position = from_actor->FetchNodePosition(from_kernel);
  921. if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
  922. MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
  923. }
  924. if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
  925. LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
  926. } else {
  927. AddDataArrow(from_actor, to_actor, from_kernel, from_output_index, to_input_index);
  928. }
  929. }
  930. void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  931. const KernelWithIndex &from_kernel_with_output_idx,
  932. const KernelWithIndex &to_kernel_with_input_idx,
  933. const KernelGraphPtr &graph) {
  934. auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
  935. MS_EXCEPTION_IF_NULL(host_ds_actor);
  936. MS_EXCEPTION_IF_NULL(from_kernel_with_output_idx.first);
  937. KernelWithIndex real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  938. // Get the position and real kernel by from kernel in the data source actor.
  939. auto position = host_ds_actor->FetchNodePosition(from_kernel_with_output_idx.first);
  940. real_from_kernel_with_output_idx.first = host_ds_actor->FetchNode(position);
  941. LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx, graph);
  942. }
// Link a data arrow whose source is a kernel actor. When the source kernel is a skipped kernel actor (an inplace
// node that has no actor of its own), redirect the arrow to the real previous output node and preserve the
// execution dependency on the skipped node by a control arrow.
void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
                                                 const KernelWithIndex &from_kernel_with_output_idx,
                                                 const KernelWithIndex &to_kernel_with_input_idx,
                                                 const KernelGraphPtr &graph) {
  auto real_from_actor = from_actor;
  auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  auto from_kernel = from_kernel_with_output_idx.first;

  // Update the from kernel info by the real node info.
  MS_EXCEPTION_IF_NULL(from_kernel);
  if (IsSkippedKernelActor(from_kernel)) {
    // The real source is the node that feeds the skipped node's first input.
    real_from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(from_kernel, 0, false);
    MS_EXCEPTION_IF_NULL(real_from_kernel_with_output_idx.first);
    // Keep the ordering dependency on the skipped node through a control arrow.
    LinkControlArrowBySkippedNode(to_actor, from_kernel);

    MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first);
    MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
                 << to_kernel_with_input_idx.first->fullname_with_scope()
                 << ", aggregate input index: " << to_kernel_with_input_idx.second
                 << ", skip node: " << from_kernel->fullname_with_scope()
                 << ", real node: " << real_from_kernel_with_output_idx.first->fullname_with_scope();
    real_from_actor =
      dynamic_cast<AbstractActor *>(FetchActor(real_from_kernel_with_output_idx.first->fullname_with_scope()));
    MS_EXCEPTION_IF_NULL(real_from_actor);
  }

  LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx,
                            graph);
}
// Link a cross-device data arrow through a copy actor: from actor --> copy actor --> to actor.
// One copy actor is shared per (from actor, from node, output index): it is created and linked from the source on
// the first call; subsequent calls only add the copy->to arrow.
void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
                                               const KernelWithIndex &from_kernel_with_output_idx,
                                               const KernelWithIndex &to_kernel_with_input_idx) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto from_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(from_kernel);

  // The name encodes the source, so the same source output always resolves to the same copy actor.
  std::string name = "copy_from:" + from_actor->GetAID().Name() + "_node:" + from_kernel->fullname_with_scope() +
                     "_output_index:" + std::to_string(from_kernel_with_output_idx.second);
  CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
  // Link between from actor and copy actor.
  if (copy_actor == nullptr) {
    // Create the copy actor.
    auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
    (void)copy_actors_.emplace_back(copy_actor_shared_ptr);
    copy_actor = copy_actor_shared_ptr.get();
    MS_EXCEPTION_IF_NULL(copy_actor);
    InsertActor(copy_actor);

    // Set the member device_contexts_ of the copy actor: index 0 is the source context, index 1 the destination.
    auto position = from_actor->FetchNodePosition(from_kernel);
    if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
      MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
    }
    auto from_device_context = from_actor->device_contexts_[position];
    auto to_device_context = to_actor->device_contexts_[0];
    MS_EXCEPTION_IF_NULL(from_device_context);
    MS_EXCEPTION_IF_NULL(to_device_context);
    (void)copy_actor->device_contexts_.emplace_back(from_device_context);
    (void)copy_actor->device_contexts_.emplace_back(to_device_context);

    // Set the member output_ of the copy actor: the device address on the to side that receives the copied data.
    if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
      copy_actor->output_ = AnfAlgo::GetMutableOutputAddr(to_kernel_with_input_idx.first, 0, false);
    } else {
      copy_actor->output_ =
        AnfAlgo::GetPrevNodeMutableOutputAddr(to_kernel_with_input_idx.first, to_kernel_with_input_idx.second, false);
    }
    MS_EXCEPTION_IF_NULL(copy_actor->output_);
    // The output address must live on the destination device type, otherwise the copy cannot land there.
    if (copy_actor->output_->DeviceType() != to_device_context->GetDeviceAddressType()) {
      MS_LOG(EXCEPTION) << "The device type is not equal, output device type:" << copy_actor->output_->DeviceType()
                        << ", to device context type:" << to_device_context->GetDeviceAddressType();
    }

    // Link between from actor and copy actor.
    AddDataArrow(from_actor, copy_actor, from_kernel, from_kernel_with_output_idx.second, 0);
  }

  // If the copy actor already exists, only need link between copy actor and to actor.
  AddDataArrow(copy_actor, to_actor, nullptr, 0, to_kernel_with_input_idx.second);
  // NOTE(review): true presumably means max ref count (never freed within the step) for the super kernel actor
  // consumer — confirm UpdateRefCount semantics.
  if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
    UpdateRefCount(copy_actor->output_.get(), true);
  } else {
    UpdateRefCount(copy_actor->output_.get(), false);
  }
}
  1021. void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node,
  1022. const KernelGraphPtr &graph) {
  1023. MS_EXCEPTION_IF_NULL(to_actor);
  1024. MS_EXCEPTION_IF_NULL(from_node);
  1025. MS_EXCEPTION_IF_NULL(graph);
  1026. // Find the real input node, include the monad node and make tuple node.
  1027. const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
  1028. prim::kPrimMakeTuple};
  1029. const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
  1030. MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
  1031. auto input_anfnode = input_kernel_with_output_idx.first;
  1032. CNodePtr input_cnode = nullptr;
  1033. if (input_anfnode->isa<CNode>()) {
  1034. input_cnode = input_anfnode->cast<CNodePtr>();
  1035. }
  1036. // Make tuple node needs to be expanded.
  1037. if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
  1038. MS_EXCEPTION_IF_NULL(input_cnode);
  1039. for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
  1040. LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph);
  1041. }
  1042. return;
  1043. }
  1044. const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
  1045. prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
  1046. // Get the real depend input by monad node which needs to link the control arrow.
  1047. std::vector<AnfNodePtr> real_depend_inputs;
  1048. if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimDepend) ||
  1049. AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimLoad)) {
  1050. MS_EXCEPTION_IF_NULL(input_cnode);
  1051. real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
  1052. // The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input
  1053. // node in this scene.
  1054. if (IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
  1055. real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
  1056. }
  1057. } else if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
  1058. MS_EXCEPTION_IF_NULL(input_cnode);
  1059. for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
  1060. real_depend_inputs.push_back(input_cnode->input(i));
  1061. }
  1062. } else {
  1063. real_depend_inputs.push_back(input_anfnode);
  1064. }
  1065. for (const auto &real_depend_input : real_depend_inputs) {
  1066. auto real_depend_input_with_idx = AnfAlgo::VisitKernelWithReturnType(real_depend_input, 0, false, return_types);
  1067. MS_EXCEPTION_IF_NULL(real_depend_input_with_idx.first);
  1068. auto real_depend_kernel = real_depend_input_with_idx.first;
  1069. // Update the real depend kernel in the subgraphs connecting scene.
  1070. if (IsInternalParameter(real_depend_kernel, graph)) {
  1071. auto front_output_with_index = graph->GetFrontNodeByInternalParameter(real_depend_kernel);
  1072. MS_EXCEPTION_IF_NULL(front_output_with_index.first);
  1073. if (graph_output_to_actor_.count(front_output_with_index) == 0) {
  1074. if (AnfAlgo::IsCallNode(front_output_with_index.first)) {
  1075. continue;
  1076. }
  1077. MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString();
  1078. }
  1079. real_depend_kernel = graph_output_to_actor_[front_output_with_index].second.first;
  1080. MS_EXCEPTION_IF_NULL(real_depend_kernel);
  1081. MS_LOG(INFO) << "The graph " << graph->graph_id() << " link control arrow by auto monad from internal parameter: "
  1082. << real_depend_input_with_idx.first->DebugString()
  1083. << ", front output node: " << front_output_with_index.first->fullname_with_scope()
  1084. << ", backend output node: " << real_depend_kernel->fullname_with_scope();
  1085. }
  1086. // The monad node and make tuple node need recursion.
  1087. if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
  1088. LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph);
  1089. continue;
  1090. }
  1091. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_kernel->fullname_with_scope()));
  1092. if (from_actor == nullptr) {
  1093. MS_LOG(DEBUG) << "Link control arrow by auto monad from depend node:" << real_depend_kernel->fullname_with_scope()
  1094. << " is not actor for the graph: " << graph->graph_id();
  1095. continue;
  1096. }
  1097. MS_LOG(INFO) << "Link control arrow by auto monad from actor: " << from_actor->GetAID().Name()
  1098. << ", to actor: " << to_actor->GetAID().Name() << " for the graph: " << graph->graph_id();
  1099. AddControlArrow(from_actor, to_actor);
  1100. }
  1101. }
  1102. void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node) {
  1103. MS_EXCEPTION_IF_NULL(to_actor);
  1104. MS_EXCEPTION_IF_NULL(skipped_node);
  1105. // Link the control arrow from all the inputs of skipped node to the user of skipped node.
  1106. auto input_num = AnfAlgo::GetInputTensorNum(skipped_node);
  1107. for (size_t i = 0; i < input_num; ++i) {
  1108. auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
  1109. MS_EXCEPTION_IF_NULL(kernel_with_index.first);
  1110. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(kernel_with_index.first->fullname_with_scope()));
  1111. MS_EXCEPTION_IF_NULL(from_actor);
  1112. MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
  1113. << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
  1114. AddControlArrow(from_actor, to_actor);
  1115. }
  1116. }
// Link control arrows for the send/recv pairs that wrap allreduce kernels so communication is
// strictly ordered. Two symmetric cases are handled:
//   1) pairs running BEFORE an allreduce: inputs --> send --> recv --> allreduce;
//   2) pairs running AFTER an allreduce:  allreduce --> send --> recv --> allreduce's users.
void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  // Case 1: the send/recv pair executes before the allreduce node.
  for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
    auto to_allreduce_node = from_iter.first;
    auto from_send_node = from_iter.second.first;
    auto from_recv_node = from_iter.second.second;
    MS_EXCEPTION_IF_NULL(to_allreduce_node);
    MS_EXCEPTION_IF_NULL(from_send_node);
    MS_EXCEPTION_IF_NULL(from_recv_node);
    MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
    auto to_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(to_allreduce_node->fullname_with_scope()));
    auto from_send_actor = dynamic_cast<KernelActor *>(FetchActor(from_send_node->fullname_with_scope()));
    auto from_recv_actor = dynamic_cast<KernelActor *>(FetchActor(from_recv_node->fullname_with_scope()));
    MS_EXCEPTION_IF_NULL(to_allreduce_actor);
    MS_EXCEPTION_IF_NULL(from_send_actor);
    MS_EXCEPTION_IF_NULL(from_recv_actor);
    // inputs of to_allreduce_actor --> from_send_actor
    // Only inputs that are themselves kernel actors are linked; other producers are skipped.
    for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
      auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
      if (input_actor != nullptr) {
        AddControlArrow(input_actor, from_send_actor);
      }
    }
    // from_send_actor --> from_recv_actor
    AddControlArrow(from_send_actor, from_recv_actor);
    // from_recv_actor --> to_allreduce_actor
    AddControlArrow(from_recv_actor, to_allreduce_actor);
  }
  // Case 2: the send/recv pair executes after the allreduce node.
  for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
    auto from_allreduce_node = to_iter.first;
    auto to_send_node = to_iter.second.first;
    auto to_recv_node = to_iter.second.second;
    MS_EXCEPTION_IF_NULL(from_allreduce_node);
    MS_EXCEPTION_IF_NULL(to_send_node);
    MS_EXCEPTION_IF_NULL(to_recv_node);
    MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
    auto from_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(from_allreduce_node->fullname_with_scope()));
    auto to_send_actor = dynamic_cast<KernelActor *>(FetchActor(to_send_node->fullname_with_scope()));
    auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
    MS_EXCEPTION_IF_NULL(from_allreduce_actor);
    MS_EXCEPTION_IF_NULL(to_send_actor);
    MS_EXCEPTION_IF_NULL(to_recv_actor);
    // from_allreduce_actor --> to_send_actor
    AddControlArrow(from_allreduce_actor, to_send_actor);
    // to_send_actor --> to_recv_actor
    AddControlArrow(to_send_actor, to_recv_actor);
    // to_recv_actor --> outputs of from_allreduce_actor
    for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
      auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
      if (output_actor != nullptr) {
        AddControlArrow(to_recv_actor, output_actor);
      }
    }
    // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
    // reused only when the recv node runs finished, which is expressed by the reference count increased.
    for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
      auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      UpdateRefCount(device_tensor.get());
      (void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
    }
  }
}
// Link the control arrows that span the whole actor set. NOTE: the call order below is significant —
// BuildNoInputKernelActor must run after kernel actors are linked, the data prepare linking depends on
// the no-input kernel actors, and the loop count linking inspects the final output arrows of all actors.
void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
                                            const std::vector<KernelActor *> &auto_monad_actors,
                                            const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  // Link the control arrows by the communication nodes to ensure communication nodes running order.
  LinkControlArrowByCommunicationNode(communication_nodes, graph_compiler_info);
  // Auto monad actor may modify the device tensor store.
  LinkDeviceTensorStoreForAutoMonadActor(auto_monad_actors);
  // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
  actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set, graph_compiler_info.strategy_);
  // Link the control arrows of data prepare actor, which depends on the no input kernel actors.
  if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) || (!IsSingleOpActorSet(actor_set))) {
    LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set,
                                        graph_compiler_info.control_node_parser_);
  }
  // Linked last: it collects the actors that still have no output arrows at this point.
  LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
                                    graph_compiler_info.control_node_parser_);
}
  1198. void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  1199. const GraphCompilerInfo &graph_compiler_info) {
  1200. const size_t kCommunicationNodesMinNum = 2;
  1201. if (communication_nodes.size() < kCommunicationNodesMinNum) {
  1202. return;
  1203. }
  1204. // Ensure communication node to execute orderly.
  1205. for (size_t i = 1; i < communication_nodes.size(); ++i) {
  1206. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
  1207. auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
  1208. MS_EXCEPTION_IF_NULL(from_actor);
  1209. MS_EXCEPTION_IF_NULL(to_actor);
  1210. AddControlArrow(from_actor, to_actor);
  1211. }
  1212. // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
  1213. // Using the multi stream to optimize the performance in the future.
  1214. for (auto &graph : graph_compiler_info.graphs_) {
  1215. MS_EXCEPTION_IF_NULL(graph);
  1216. auto &execution_order = graph->execution_order();
  1217. for (size_t i = 1; i < execution_order.size(); ++i) {
  1218. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i - 1]->fullname_with_scope()));
  1219. auto to_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i]->fullname_with_scope()));
  1220. if ((from_actor != nullptr) && (to_actor != nullptr)) {
  1221. AddControlArrow(from_actor, to_actor);
  1222. }
  1223. }
  1224. }
  1225. }
  1226. void GraphScheduler::LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor,
  1227. const ActorSet *actor_set,
  1228. const ControlNodeParserPtr &parser) {
  1229. MS_EXCEPTION_IF_NULL(data_prepare_actor);
  1230. MS_EXCEPTION_IF_NULL(actor_set);
  1231. MS_EXCEPTION_IF_NULL(parser);
  1232. // Data prepare actor --> data source actor.
  1233. for (auto &data_source_actor : actor_set->data_source_actors_) {
  1234. MS_EXCEPTION_IF_NULL(data_source_actor);
  1235. AddControlArrow(data_prepare_actor, data_source_actor.get());
  1236. }
  1237. // In control flow, control arrow of no input kernel actor needs to be connected to the corresponding entrance actor.
  1238. if (!parser->IsInited()) {
  1239. // Data prepare actor --> no input kernel actor.
  1240. for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  1241. MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
  1242. AddControlArrow(data_prepare_actor, no_input_kernel_actor.get());
  1243. }
  1244. }
  1245. // Data prepare actor --> loop count actor.
  1246. if ((actor_set->data_source_actors_.size() + actor_set->no_input_kernel_actors_.size() == 0) &&
  1247. (actor_set->loop_count_actor_ != nullptr)) {
  1248. AddControlArrow(data_prepare_actor, actor_set->loop_count_actor_.get());
  1249. }
  1250. }
  1251. void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  1252. const ControlNodeParserPtr &parser) {
  1253. MS_EXCEPTION_IF_NULL(actor_set);
  1254. MS_EXCEPTION_IF_NULL(parser);
  1255. // There is no loop count actor in step mode.
  1256. if (loop_count_actor == nullptr) {
  1257. return;
  1258. }
  1259. // Collect the actors which have no output.
  1260. std::vector<MemoryAwareActor *> no_output_actors;
  1261. for (auto &super_actor : actor_set->super_kernel_actors_) {
  1262. if ((super_actor->output_data_arrows_.size() == 0) && (super_actor->output_control_arrows_.size() == 0)) {
  1263. (void)no_output_actors.emplace_back(super_actor.get());
  1264. }
  1265. }
  1266. // In control flow scenario, no output actor needs to be connected to the corresponding exit actor, not loop count.
  1267. if (!parser->IsInited()) {
  1268. for (auto &kernel_actor : actor_set->kernel_actors_) {
  1269. // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
  1270. if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) {
  1271. (void)no_output_actors.emplace_back(kernel_actor.get());
  1272. }
  1273. }
  1274. }
  1275. for (auto &data_actor : actor_set->data_source_actors_) {
  1276. if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
  1277. (void)no_output_actors.emplace_back(data_actor.get());
  1278. }
  1279. }
  1280. for (auto &copy_actor : copy_actors_) {
  1281. if ((copy_actor->output_data_arrows_.size() == 0) && (copy_actor->output_control_arrows_.size() == 0)) {
  1282. (void)no_output_actors.emplace_back(copy_actor.get());
  1283. }
  1284. }
  1285. // No output actor --> loop count actor.
  1286. for (auto &no_output_actor : no_output_actors) {
  1287. AddControlArrow(no_output_actor, loop_count_actor);
  1288. }
  1289. // Loop count actor --> output actor.
  1290. AddControlArrow(loop_count_actor, actor_set->output_actor_.get());
  1291. // Loop count actor --> data prepare actor.
  1292. MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  1293. loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
  1294. }
// Link result arrows from every graph's output nodes to the output actor, which gathers the final
// outputs of one execution step at their front-end output positions.
void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
                                                         const GraphCompilerInfo &graph_compiler_info) {
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep ||
      (graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited())) {
    // In control flow, the exit actor of the root graph sends output data to the output actor.
    return;
  }
  MS_EXCEPTION_IF_NULL(to_actor);
  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    MS_EXCEPTION_IF_NULL(graph);
    auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
    std::set<std::vector<size_t>> unique_output_positions;
    std::set<KernelWithIndex> unique_outputs;
    // Deduplicate the graph outputs and drop internal parameters (they are produced by another graph).
    for (const auto &output : outputs) {
      if (IsInternalParameter(output.first, graph)) {
        MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
        continue;
      }
      (void)unique_outputs.insert(output);
    }
    for (const auto &output_with_index : unique_outputs) {
      MS_EXCEPTION_IF_NULL(output_with_index.first);
      auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph);
      const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
      // Outputs not present in the front-end output order are not user-visible and need no arrow.
      if (iter == graph_compiler_info.origin_outputs_order_.end()) {
        continue;
      }
      // Skip duplicate position.
      if (unique_output_positions.count(iter->second) > 0) {
        continue;
      }
      (void)unique_output_positions.insert(iter->second);
      // One front output may map to several positions in the final output tuple.
      for (auto &output_position : iter->second) {
        if (output_position >= to_actor->device_contexts_.size()) {
          MS_LOG(EXCEPTION) << "The output position is out of range.";
        }
        to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i];
        // The graph output is from device tensor store.
        if (IsPersistentDeviceTensor(output_with_index.first)) {
          (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
          auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, output_with_index.second, false);
          MS_EXCEPTION_IF_NULL(device_tensor);
          // The output actor need use the relevant information of node to create output tensor.
          device_tensor->SetNodeIndex(output_with_index.first, output_with_index.second);
          continue;
        }
        // The graph output is from kernel actor or data source actor.
        auto kernel_type = KernelTransformType::kUnknown;
        std::string kernel_name = "";
        FetchKernelTransformTypeAndName(output_with_index.first, graph, graph_compiler_info, &kernel_type,
                                        &kernel_name);
        auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
        if (from_actor == nullptr) {
          // May exist the from kernel that no need link in the pynative mode.
          continue;
        }
        auto real_from_kernel = output_with_index.first;
        // Update the real node in the host data source actor.
        if (kernel_type == KernelTransformType::kHostDataSourceActor) {
          auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
          MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
          auto position = host_queue_ds_actor->FetchNodePosition(output_with_index.first);
          real_from_kernel = host_queue_ds_actor->FetchNode(position);
          UpdateRefCount(output_with_index.first, output_with_index.second, true);
        }
        AddResultArrow(from_actor, to_actor, real_from_kernel, output_with_index.second, output_position);
      }
    }
  }
}
  1365. void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
  1366. const size_t kNeedUpdateDeviceTensorStoreNum = 2;
  1367. for (auto &kernel_actor : auto_monad_actors) {
  1368. MS_EXCEPTION_IF_NULL(kernel_actor);
  1369. for (auto &device_tensor_store_key : kernel_actor->device_tensor_store_keys_) {
  1370. auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
  1371. if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
  1372. continue;
  1373. }
  1374. // Create the copy actor.
  1375. std::string name = "copy_from:" + kernel_actor->GetAID().Name() +
  1376. "_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
  1377. if (FetchActor(name) != nullptr) {
  1378. continue;
  1379. }
  1380. auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
  1381. MS_EXCEPTION_IF_NULL(copy_actor);
  1382. (void)copy_actors_.emplace_back(copy_actor);
  1383. InsertActor(copy_actor.get());
  1384. // Set the member of the copy actor.
  1385. (void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
  1386. auto input_device_context = kernel_actor->device_contexts_[0];
  1387. (void)copy_actor->device_contexts_.emplace_back(input_device_context);
  1388. auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
  1389. ? device_tensors[1]
  1390. : device_tensors[0];
  1391. MS_EXCEPTION_IF_NULL(another_device_tensor);
  1392. auto another_device_type = another_device_tensor->DeviceType();
  1393. const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
  1394. {device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
  1395. MS_EXCEPTION_IF_NULL(another_device_context);
  1396. (void)copy_actor->device_contexts_.emplace_back(another_device_context);
  1397. MS_LOG(INFO) << "The kernel actor: " << kernel_actor->GetAID().Name()
  1398. << "has control arrows number:" << kernel_actor->output_control_arrows_.size();
  1399. // Link from copy actor to kernel actor users.
  1400. for (auto &output_contorl : kernel_actor->output_control_arrows_) {
  1401. (void)copy_actor->output_control_arrows_.emplace_back(output_contorl);
  1402. }
  1403. // Move the control arrows from kernel actor to kernel actor users.
  1404. kernel_actor->output_control_arrows_.clear();
  1405. // Link from kernel actor to copy actor.
  1406. AddControlArrow(kernel_actor, copy_actor.get());
  1407. }
  1408. }
  1409. }
  1410. void GraphScheduler::AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor) {
  1411. MS_EXCEPTION_IF_NULL(device_tensor);
  1412. DeviceTensorStore::GetInstance().Insert(const_cast<AnfNode *>(anf_node), device_tensor);
  1413. UpdateRefCount(device_tensor.get(), true);
  1414. }
  1415. void GraphScheduler::AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1416. const AnfNodePtr &from_kernel, size_t from_output_index, size_t to_input_index) {
  1417. MS_EXCEPTION_IF_NULL(from_actor);
  1418. MS_EXCEPTION_IF_NULL(to_actor);
  1419. auto data_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), to_input_index);
  1420. (void)from_actor->output_data_arrows_.emplace_back(data_arrow);
  1421. (void)from_actor->output_data_nodes_.emplace_back(from_kernel);
  1422. to_actor->input_datas_num_++;
  1423. (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
  1424. if (from_kernel == nullptr) {
  1425. return;
  1426. }
  1427. // Update the reference count of from_kernel.
  1428. // The device address of super kernel actor can't be changed, so set the max reference count.
  1429. if ((from_actor->type_ == KernelTransformType::kSuperKernelActor) ||
  1430. (to_actor->type_ == KernelTransformType::kSuperKernelActor)) {
  1431. UpdateRefCount(from_kernel, from_output_index, true);
  1432. } else {
  1433. UpdateRefCount(from_kernel, from_output_index, false);
  1434. }
  1435. }
// Link a result arrow that carries a graph output from from_actor to the output actor at
// output_position of the final output tuple.
void GraphScheduler::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
                                    const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(from_kernel);
  auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
  // NOTE(review): unlike AddDataArrow, the result arrow and its node are inserted at the FRONT of the
  // sender's output lists — presumably so result sending happens before the ordinary data arrows;
  // confirm before changing this ordering.
  (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
  (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
  to_actor->input_datas_num_++;
  (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
  auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
  MS_EXCEPTION_IF_NULL(device_tensor);
  // The output actor need use the relevant information of node to create output tensor.
  device_tensor->SetNodeIndex(from_kernel, from_output_index);
  // The device address of a super kernel actor output can't be changed, so record it as persisted.
  if (from_actor->type_ == KernelTransformType::kSuperKernelActor) {
    (void)to_actor->output_address_persisted_nodes_.insert(from_kernel);
  }
  // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
  UpdateRefCount(device_tensor.get(), true);
}
  1456. void GraphScheduler::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {
  1457. MS_EXCEPTION_IF_NULL(from_actor);
  1458. MS_EXCEPTION_IF_NULL(to_actor);
  1459. (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
  1460. to_actor->input_controls_num_++;
  1461. (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
  1462. }
  1463. void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
  1464. MS_EXCEPTION_IF_NULL(actor_set);
  1465. auto actors = CollectActors(actor_set);
  1466. for (auto &actor : actors) {
  1467. MS_EXCEPTION_IF_NULL(actor);
  1468. if (actor->type_ >= KernelTransformType::kSwitchActor) {
  1469. continue;
  1470. }
  1471. if ((actor->input_datas_num_ != actor->input_data_arrow_aids_.size()) ||
  1472. (actor->input_controls_num_ != actor->input_control_arrow_aids_.size())) {
  1473. MS_LOG(EXCEPTION) << "The input num of " << actor->GetAID().Name()
  1474. << " is wrong, expect data num: " << actor->input_datas_num_
  1475. << ", actual data num: " << actor->input_data_arrow_aids_.size()
  1476. << ", expect control num: " << actor->input_controls_num_
  1477. << ", actual control num: " << actor->input_control_arrow_aids_.size();
  1478. }
  1479. if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->output_data_arrows_.size() == 0) &&
  1480. (actor->output_control_arrows_.size() == 0)) {
  1481. MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user.";
  1482. }
  1483. if ((actor->type_ != KernelTransformType::kDataPrepareActor) && (actor->input_datas_num_ == 0) &&
  1484. (actor->input_controls_num_ == 0)) {
  1485. MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source.";
  1486. }
  1487. // Check the input of kernel actors and copy actors.
  1488. if ((actor->type_ == KernelTransformType::kKernelActor) || (actor->type_ == KernelTransformType::kCopyActor)) {
  1489. size_t expect_toal_input_num = 1;
  1490. if (actor->type_ == KernelTransformType::kKernelActor) {
  1491. auto kernel_actor = dynamic_cast<KernelActor *>(actor.get());
  1492. MS_EXCEPTION_IF_NULL(kernel_actor);
  1493. expect_toal_input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
  1494. }
  1495. auto input_data_num = actor->input_datas_num_;
  1496. auto device_tensor_store_num = actor->device_tensor_store_keys_.size();
  1497. if (input_data_num + device_tensor_store_num != expect_toal_input_num) {
  1498. MS_LOG(EXCEPTION) << "The input building of " << actor->GetAID().Name()
  1499. << " is wrong, input data num: " << input_data_num
  1500. << ", device tensor store num: " << device_tensor_store_num
  1501. << ", total input num: " << expect_toal_input_num;
  1502. }
  1503. }
  1504. }
  1505. // Check the output actor.
  1506. auto output_actor = actor_set->output_actor_;
  1507. MS_EXCEPTION_IF_NULL(output_actor);
  1508. if (output_actor->input_datas_num_ + output_actor->device_tensor_store_keys_.size() != output_actor->outputs_num_) {
  1509. MS_LOG(EXCEPTION) << "The outputs num of output actor is wrong, the total outputs num: "
  1510. << output_actor->outputs_num_ << ", the input data arrows num: " << output_actor->input_datas_num_
  1511. << ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size();
  1512. }
  1513. }
  1514. void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
  1515. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  1516. const auto &graph = graph_compiler_info.graphs_[i];
  1517. const auto &device_context = graph_compiler_info.device_contexts_[i];
  1518. MS_EXCEPTION_IF_NULL(graph);
  1519. MS_EXCEPTION_IF_NULL(device_context);
  1520. for (auto &value_node : graph->graph_value_nodes()) {
  1521. MS_EXCEPTION_IF_NULL(value_node);
  1522. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  1523. MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
  1524. continue;
  1525. }
  1526. auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
  1527. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  1528. device_tensor->SetNodeIndex(value_node, 0);
  1529. AddDeviceTensorStore(front_node.get(), device_tensor);
  1530. }
  1531. for (auto &input_node : graph->input_nodes()) {
  1532. MS_EXCEPTION_IF_NULL(input_node);
  1533. AnfNodePtr sub_front_node = nullptr;
  1534. if (IsInternalParameter(input_node, graph)) {
  1535. auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node);
  1536. sub_front_node = front_output_with_index.first;
  1537. } else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) {
  1538. sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
  1539. }
  1540. if (sub_front_node == nullptr) {
  1541. continue;
  1542. }
  1543. // The sub front nodes share the device tensor store with the root front node.
  1544. MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  1545. auto front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
  1546. MS_EXCEPTION_IF_NULL(front_node);
  1547. MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString()
  1548. << ", root front node:" << front_node->DebugString();
  1549. auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
  1550. MS_EXCEPTION_IF_NULL(device_tensor);
  1551. if (IsPersistentDeviceTensor(input_node)) {
  1552. device_tensor->SetNodeIndex(input_node, 0);
  1553. AddDeviceTensorStore(front_node.get(), device_tensor);
  1554. }
  1555. // Share the weight in the host and device, then input_node is internal parameter and front_node is weight.
  1556. if (!IsPersistentDeviceTensor(front_node)) {
  1557. continue;
  1558. }
  1559. // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
  1560. if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
  1561. MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
  1562. << ", type:" << device_context->GetDeviceAddressType();
  1563. auto other_type_device_tensor = device_context->CreateDeviceAddress(
  1564. nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
  1565. other_type_device_tensor->SetNodeIndex(input_node, 0);
  1566. AddDeviceTensorStore(front_node.get(), other_type_device_tensor);
  1567. }
  1568. }
  1569. }
  1570. }
  1571. void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, const KernelGraphPtr &graph,
  1572. const GraphCompilerInfo &graph_compiler_info,
  1573. KernelTransformType *const kernel_type,
  1574. std::string *const kernel_name) {
  1575. MS_EXCEPTION_IF_NULL(graph);
  1576. MS_EXCEPTION_IF_NULL(kernel_type);
  1577. MS_EXCEPTION_IF_NULL(kernel_name);
  1578. // In sink mode, the data exchange between child graphs is expressed as parameters. These parameters are stored
  1579. // in the graph and should be obtained from the super kernel actor.
  1580. if (graph->is_executing_sink() && ((node == nullptr) || node->isa<CNode>() || graph->IsChildGraphResult(node))) {
  1581. *kernel_type = KernelTransformType::kSuperKernelActor;
  1582. *kernel_name = graph->ToString() + "_SuperKernelActor";
  1583. return;
  1584. }
  1585. MS_EXCEPTION_IF_NULL(node);
  1586. if (IsDeviceQueueDSActor(node, graph_compiler_info.strategy_)) {
  1587. *kernel_type = KernelTransformType::kDeviceDataSourceActor;
  1588. *kernel_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
  1589. } else if (IsHostQueueDSActor(node, graph, graph_compiler_info.origin_parameters_order_,
  1590. graph_compiler_info.strategy_)) {
  1591. *kernel_type = KernelTransformType::kHostDataSourceActor;
  1592. *kernel_name = graph_compiler_info.name_ + "_HostDSActor";
  1593. } else if (IsKernelActor(node, graph_compiler_info.strategy_)) {
  1594. *kernel_type = KernelTransformType::kKernelActor;
  1595. *kernel_name = node->fullname_with_scope();
  1596. } else if (IsInternalParameter(node, graph)) {
  1597. *kernel_type = KernelTransformType::kInternalParameter;
  1598. *kernel_name = "";
  1599. } else if (IsPersistentDeviceTensor(node)) {
  1600. *kernel_type = KernelTransformType::kDeviceTensorStore;
  1601. *kernel_name = "";
  1602. } else {
  1603. // May exist the from kernel that no need link in the pynative mode.
  1604. MS_LOG(DEBUG) << "Invalid from kernel: " << node->DebugString();
  1605. *kernel_type = KernelTransformType::kUnknown;
  1606. *kernel_name = "";
  1607. }
  1608. }
  1609. void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const {
  1610. MS_EXCEPTION_IF_NULL(actor_set);
  1611. const auto &context_ptr = MsContext::GetInstance();
  1612. MS_EXCEPTION_IF_NULL(context_ptr);
  1613. auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  1614. if (!save_graphs) {
  1615. return;
  1616. }
  1617. std::string filename = GetSaveGraphsPathName("actor_set_" + actor_set->name_ + ".ir");
  1618. std::ofstream ofs(filename);
  1619. if (!ofs.is_open()) {
  1620. MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
  1621. return;
  1622. }
  1623. DumpDeviceTensorStore(graph_compiler_info, ofs);
  1624. DumpDataPrepareActor(actor_set->data_prepare_actor_, ofs);
  1625. DumpDSActors(actor_set->data_source_actors_, ofs);
  1626. DumpKernelActors(actor_set->kernel_actors_, ofs);
  1627. DumpSuperKernelActors(actor_set->super_kernel_actors_, ofs);
  1628. DumpNoInputKernelActors(actor_set->no_input_kernel_actors_, ofs);
  1629. DumpCopyActors(actor_set->copy_actors_, ofs);
  1630. DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
  1631. DumpOutputActor(actor_set->output_actor_, ofs);
  1632. DumpControlActors(actor_set->control_actors_, ofs);
  1633. }
  1634. void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
  1635. ofs << "[Device tensor stores]\n";
  1636. for (const auto &graph : graph_compiler_info.graphs_) {
  1637. MS_EXCEPTION_IF_NULL(graph);
  1638. ofs << "\tgraph_id:" << graph->graph_id() << "\tis_executing_sink:" << graph->is_executing_sink()
  1639. << "\tis_loop_count_sink:" << graph->is_loop_count_sink()
  1640. << "\texecution_strategy:" << graph_compiler_info.strategy_ << "\n";
  1641. for (auto &value_node : graph->graph_value_nodes()) {
  1642. MS_EXCEPTION_IF_NULL(value_node);
  1643. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  1644. continue;
  1645. }
  1646. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  1647. MS_EXCEPTION_IF_NULL(front_node);
  1648. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  1649. ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
  1650. << "\n";
  1651. for (const auto &device_tensor : device_tensors) {
  1652. MS_EXCEPTION_IF_NULL(device_tensor);
  1653. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  1654. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  1655. << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
  1656. }
  1657. }
  1658. for (auto &input_node : graph->input_nodes()) {
  1659. MS_EXCEPTION_IF_NULL(input_node);
  1660. if (!IsPersistentDeviceTensor(input_node)) {
  1661. continue;
  1662. }
  1663. const auto &sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
  1664. // The sub front nodes share the device tensor store with the root front node.
  1665. auto front_node = sub_front_node;
  1666. if (graph_compiler_info.control_node_parser_ != nullptr) {
  1667. front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
  1668. }
  1669. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  1670. MS_EXCEPTION_IF_NULL(front_node);
  1671. ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
  1672. << "\n";
  1673. for (const auto &device_tensor : device_tensors) {
  1674. MS_EXCEPTION_IF_NULL(device_tensor);
  1675. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  1676. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  1677. << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
  1678. }
  1679. }
  1680. ofs << "\n";
  1681. }
  1682. }
  1683. } // namespace runtime
  1684. } // namespace mindspore