You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_scheduler.cc 90 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_scheduler.h"
  17. #include "runtime/framework/actor/memory_manager_actor.h"
  18. #include "runtime/framework/actor/debug_actor.h"
  19. #include "runtime/framework/actor/recorder_actor.h"
  20. #include "runtime/hardware/device_context_manager.h"
  21. #include "mindrt/src/actor/actormgr.h"
  22. #include "mindrt/include/async/async.h"
  23. #include "backend/session/anf_runtime_algorithm.h"
  24. #include "backend/optimizer/common/helper.h"
  25. #include "utils/config_manager.h"
  26. #include "utils/log_adapter.h"
  27. #include "utils/convert_utils.h"
  28. #include "utils/ms_context.h"
  29. #include "utils/profile.h"
  30. #if !defined(_WIN32) && !defined(_WIN64)
  31. #include "utils/signal_util.h"
  32. #endif
  33. #ifndef ENABLE_SECURITY
  34. #include "debug/data_dump/dump_json_parser.h"
  35. #endif
  36. #ifdef ENABLE_DUMP_IR
  37. #include "debug/rdr/recorder_manager.h"
  38. #include "debug/rdr/running_data_recorder.h"
  39. #endif
  40. #ifdef ENABLE_DEBUGGER
  41. #include "debug/debugger/debugger.h"
  42. #endif
  43. #include "profiler/device/profiling.h"
  44. #include "debug/common.h"
  45. namespace mindspore {
  46. namespace runtime {
  47. namespace {
  48. bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
  49. MS_EXCEPTION_IF_NULL(from_device_context);
  50. MS_EXCEPTION_IF_NULL(to_device_context);
  51. if (from_device_context->GetDeviceAddressType() == to_device_context->GetDeviceAddressType()) {
  52. return false;
  53. } else {
  54. return true;
  55. }
  56. }
  57. inline bool IsSingleOpActorSet(const ActorSet *actor_set) {
  58. MS_EXCEPTION_IF_NULL(actor_set);
  59. return actor_set->kernel_actors_.size() == 1;
  60. }
  61. // Convert the actors vector by the actor set.
  62. std::vector<AbstractActorPtr> CollectActors(const ActorSet *actor_set) {
  63. MS_EXCEPTION_IF_NULL(actor_set);
  64. std::vector<AbstractActorPtr> actors;
  65. if (actor_set->data_prepare_actor_ != nullptr) {
  66. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->data_prepare_actor_));
  67. }
  68. for (auto &data_source_actor : actor_set->data_source_actors_) {
  69. MS_EXCEPTION_IF_NULL(data_source_actor);
  70. (void)actors.emplace_back(static_cast<AbstractActorPtr>(data_source_actor));
  71. }
  72. for (auto &kernel_actor : actor_set->kernel_actors_) {
  73. MS_EXCEPTION_IF_NULL(kernel_actor);
  74. (void)actors.emplace_back(static_cast<AbstractActorPtr>(kernel_actor));
  75. }
  76. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  77. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  78. (void)actors.emplace_back(static_cast<AbstractActorPtr>(super_kernel_actor));
  79. }
  80. for (auto &copy_actor : actor_set->copy_actors_) {
  81. MS_EXCEPTION_IF_NULL(copy_actor);
  82. (void)actors.emplace_back(static_cast<AbstractActorPtr>(copy_actor));
  83. }
  84. if (actor_set->loop_count_actor_ != nullptr) {
  85. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->loop_count_actor_));
  86. }
  87. if (actor_set->output_actor_ != nullptr) {
  88. (void)actors.emplace_back(static_cast<AbstractActorPtr>(actor_set->output_actor_));
  89. }
  90. if (actor_set->control_actors_ != nullptr) {
  91. const auto &control_actor_set = actor_set->control_actors_;
  92. for (auto &switch_actor : control_actor_set->switch_actors_) {
  93. MS_EXCEPTION_IF_NULL(switch_actor);
  94. (void)actors.emplace_back(static_cast<AbstractActorPtr>(switch_actor));
  95. }
  96. for (auto &gather_actor : control_actor_set->gather_actors_) {
  97. MS_EXCEPTION_IF_NULL(gather_actor);
  98. (void)actors.emplace_back(static_cast<AbstractActorPtr>(gather_actor));
  99. }
  100. for (auto &entrance_actor : control_actor_set->entrance_actors_) {
  101. MS_EXCEPTION_IF_NULL(entrance_actor);
  102. (void)actors.emplace_back(static_cast<AbstractActorPtr>(entrance_actor));
  103. }
  104. for (auto &exit_actor : control_actor_set->exit_actors_) {
  105. MS_EXCEPTION_IF_NULL(exit_actor);
  106. (void)actors.emplace_back(static_cast<AbstractActorPtr>(exit_actor));
  107. }
  108. for (auto &stack_actor : control_actor_set->stack_actors_) {
  109. MS_EXCEPTION_IF_NULL(stack_actor);
  110. (void)actors.emplace_back(static_cast<AbstractActorPtr>(stack_actor));
  111. }
  112. }
  113. return actors;
  114. }
  115. void ClearNodeInfo(const KernelGraphPtr &graph) {
  116. MS_EXCEPTION_IF_NULL(graph);
  117. // Clear input parameter device tensor and device tensor store.
  118. for (const auto &input_node : graph->input_nodes()) {
  119. MS_EXCEPTION_IF_NULL(input_node);
  120. if (!input_node->isa<Parameter>()) {
  121. continue;
  122. }
  123. auto parameter = input_node->cast<ParameterPtr>();
  124. MS_EXCEPTION_IF_NULL(parameter);
  125. parameter->DecreaseUsedGraphCount();
  126. // Only the parameter has no graph used, then clear the device tensor.
  127. if (parameter->used_graph_count() != 0) {
  128. continue;
  129. }
  130. auto front_input_node = FetchFrontNodeByBackendNode(input_node, graph);
  131. DeviceTensorStore::GetInstance().Remove(front_input_node.get());
  132. size_t output_num = AnfAlgo::GetOutputTensorNum(input_node);
  133. for (size_t index = 0; index < output_num; ++index) {
  134. if (AnfAlgo::OutputAddrExist(input_node, index)) {
  135. AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
  136. }
  137. }
  138. }
  139. // Clear input value node device tensor and device tensor store.
  140. for (const auto &value_node : graph->graph_value_nodes()) {
  141. auto front_value_node = FetchFrontNodeByBackendNode(value_node, graph);
  142. DeviceTensorStore::GetInstance().Remove(front_value_node.get());
  143. if (AnfAlgo::OutputAddrExist(value_node, 0)) {
  144. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  145. }
  146. }
  147. // Clear cnode device tensor.
  148. for (const auto &cnode : graph->execution_order()) {
  149. size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  150. for (size_t index = 0; index < output_num; ++index) {
  151. if (AnfAlgo::OutputAddrExist(cnode, index)) {
  152. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  153. }
  154. }
  155. }
  156. }
  157. #if !defined(_WIN32) && !defined(_WIN64)
  158. void IntHandler(int, siginfo_t *, void *) {
  159. int this_pid = getpid();
  160. MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
  161. (void)kill(this_pid, SIGTERM);
  162. }
  163. #endif
  164. } // namespace
  165. void GraphScheduler::Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) noexcept {
  166. // Terminate the actors of actor info.
  167. if (actors_.count(actor_info) > 0) {
  168. auto actor_manager = ActorMgr::GetActorMgrRef();
  169. if (actor_manager == nullptr) {
  170. MS_LOG(ERROR) << "Actor manager is not exist.";
  171. return;
  172. }
  173. auto actor_set = actors_[actor_info];
  174. auto base_actors = CollectActors(actor_set.get());
  175. for (auto &base_actor : base_actors) {
  176. MS_EXCEPTION_IF_NULL(base_actor);
  177. EraseActor(base_actor->GetAID().Name());
  178. actor_manager->Terminate(base_actor->GetAID());
  179. }
  180. }
  181. // Clear device tensor and device tensor store.
  182. for (auto &graph : graphs) {
  183. ClearNodeInfo(graph);
  184. }
  185. // Clear global maps of actor info.
  186. (void)actors_.erase(actor_info);
  187. }
  188. void GraphScheduler::Clear() {
  189. // Terminate all actors.
  190. auto actor_manager = ActorMgr::GetActorMgrRef();
  191. MS_EXCEPTION_IF_NULL(actor_manager);
  192. actor_manager->Finalize();
  193. // Clear the member of DeviceTensorStore.
  194. DeviceTensorStore::GetInstance().Clear();
  195. // Clear global maps.
  196. actors_.clear();
  197. ClearAllActors();
  198. }
  199. void GraphScheduler::ClearActorData(const ActorSet *actor_set) {
  200. MS_EXCEPTION_IF_NULL(actor_set);
  201. control_node_scheduler_.ClearActorData(actor_set->control_actors_.get());
  202. }
  203. using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, AbstractActor *const, const KernelWithIndex &,
  204. const KernelWithIndex &, const KernelGraphPtr &);
  205. static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc;
  206. void GraphScheduler::Initialize() {
  207. if (init_) {
  208. return;
  209. }
  210. init_ = true;
  211. (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceDataSourceActor,
  212. &GraphScheduler::LinkDataArrowForBaseActor);
  213. (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kHostDataSourceActor,
  214. &GraphScheduler::LinkDataArrowForHostDSActor);
  215. (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kKernelActor, &GraphScheduler::LinkDataArrowForKernelActor);
  216. (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kSuperKernelActor,
  217. &GraphScheduler::LinkDataArrowForBaseActor);
  218. (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceTensorStore,
  219. &GraphScheduler::LinkDataArrowForDeviceTensorStore);
  220. (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
  221. &GraphScheduler::LinkDataArrowForInternalParameter);
  222. // Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
  223. size_t actor_thread_num = 0;
  224. size_t actor_and_kernel_thread_num = 0;
  225. ComputeThreadNums(&actor_thread_num, &actor_and_kernel_thread_num);
  226. auto actor_manager = ActorMgr::GetActorMgrRef();
  227. MS_EXCEPTION_IF_NULL(actor_manager);
  228. auto ret = actor_manager->Initialize(true, actor_thread_num, actor_and_kernel_thread_num);
  229. if (ret != MINDRT_OK) {
  230. MS_LOG(EXCEPTION) << "Actor manager init failed.";
  231. }
  232. (void)common::SetOMPThreadNum();
  233. auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS");
  234. MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
  235. << ", the kernel thread number: " << (actor_and_kernel_thread_num - actor_thread_num)
  236. << ", the used OMP thread number: " << OMP_thread_num_used;
  237. BuildAndScheduleGlobalActor();
  238. }
  239. void GraphScheduler::BuildAndScheduleGlobalActor() {
  240. auto actor_manager = ActorMgr::GetActorMgrRef();
  241. MS_EXCEPTION_IF_NULL(actor_manager);
  242. // Create and schedule memory manager actor.
  243. auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
  244. MS_EXCEPTION_IF_NULL(memory_manager_actor);
  245. memory_manager_aid_ = memory_manager_actor->GetAID();
  246. auto base_actor = static_cast<ActorReference>(memory_manager_actor);
  247. // Bind single thread to response to memory alloc and free quickly.
  248. (void)actor_manager->Spawn(base_actor, false);
  249. // Create and schedule recorder actor.
  250. auto recorder_actor = std::make_shared<RecorderActor>();
  251. MS_EXCEPTION_IF_NULL(recorder_actor);
  252. recorder_aid_ = &(recorder_actor->GetAID());
  253. auto base_recorder_actor = static_cast<ActorReference>(recorder_actor);
  254. (void)actor_manager->Spawn(base_recorder_actor, true);
  255. // Create and schedule debug actor.
  256. #ifndef ENABLE_SECURITY
  257. bool debugger_actor_need = DumpJsonParser::GetInstance().e2e_dump_enabled();
  258. #endif
  259. #ifdef ENABLE_DEBUGGER
  260. if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
  261. debugger_actor_need = true;
  262. }
  263. #endif
  264. #ifndef ENABLE_SECURITY
  265. if (debugger_actor_need) {
  266. auto debug_actor = std::make_shared<DebugActor>();
  267. MS_EXCEPTION_IF_NULL(debug_actor);
  268. debug_aid_ = &(debug_actor->GetAID());
  269. auto base_debug_actor = static_cast<ActorReference>(debug_actor);
  270. (void)actor_manager->Spawn(base_debug_actor, true);
  271. }
  272. #endif
  273. }
  274. ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info) {
  275. MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
  276. if (graph_compiler_info.graphs_.size() == 0) {
  277. MS_LOG(EXCEPTION) << "The number of graphs is zero.";
  278. }
  279. if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
  280. MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
  281. }
  282. PersistDeviceTensor(graph_compiler_info);
  283. const auto &actor_set = Build(graph_compiler_info);
  284. MS_EXCEPTION_IF_NULL(actor_set);
  285. CacheGraphOutputToActor(graph_compiler_info);
  286. Link(actor_set.get(), graph_compiler_info);
  287. // The copy actors are built in the link, so need push into the actor set after link.
  288. actor_set->copy_actors_ = copy_actors_;
  289. DumpActor(actor_set.get(), graph_compiler_info);
  290. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
  291. CheckActorValid(actor_set.get());
  292. }
  293. MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";
  294. // Local maps and vectors clear.
  295. graph_output_to_actor_.clear();
  296. copy_actors_.clear();
  297. return actor_set.get();
  298. }
  299. void GraphScheduler::Schedule(const ActorSet *actor_set) {
  300. MS_EXCEPTION_IF_NULL(actor_set);
  301. auto actors = CollectActors(actor_set);
  302. // Schedule actors.
  303. auto actor_manager = ActorMgr::GetActorMgrRef();
  304. MS_EXCEPTION_IF_NULL(actor_manager);
  305. for (auto actor : actors) {
  306. (void)actor_manager->Spawn(actor);
  307. }
  308. }
  309. void GraphScheduler::Run(ActorSet *const actor_set, const std::vector<DeviceContext *> &device_contexts,
  310. const std::vector<std::vector<TensorPtr>> &input_tensors,
  311. const std::vector<TensorPtr> &input_tensors_with_value_node, GraphExecutionStrategy strategy) {
  312. MS_EXCEPTION_IF_NULL(actor_set);
  313. MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  314. #if !defined(_WIN32) && !defined(_WIN64)
  315. SignalGuard sg(IntHandler);
  316. #endif
  317. // Construct OpContext.
  318. OpContext<DeviceTensor> op_context;
  319. std::vector<Promise<int>> result(1);
  320. op_context.sequential_num_ = RandInt::Instance().Get();
  321. op_context.results_ = &result;
  322. if ((strategy == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  323. actor_set->data_prepare_actor_->PrepareData(input_tensors, &op_context);
  324. MS_EXCEPTION_IF_NULL(actor_set->kernel_actors_[0]);
  325. actor_set->kernel_actors_[0]->RunOpControlWithInputTensor(nullptr, &op_context, &input_tensors_with_value_node);
  326. return;
  327. }
  328. // Trigger data prepare actor running.
  329. MS_EXCEPTION_IF_NULL(ActorMgr::GetActorMgrRef());
  330. auto thread_pool = ActorMgr::GetActorMgrRef()->GetActorThreadPool();
  331. MS_EXCEPTION_IF_NULL(thread_pool);
  332. if (actor_set->is_multi_thread_execution_) {
  333. thread_pool->SetSpinCountMaxValue();
  334. }
  335. ActorDispatcher::is_multi_thread_execution(actor_set->is_multi_thread_execution_);
  336. double start_time = GetTime();
  337. ActorDispatcher::Send(actor_set->data_prepare_actor_->GetAID(), &DataPrepareActor::PrepareData, input_tensors,
  338. &op_context);
  339. // Get the run result.
  340. auto result_future = result[0].GetFuture();
  341. result_future.Wait();
  342. MsException::Instance().CheckException();
  343. thread_pool->SetSpinCountMinValue();
  344. if (!result_future.IsOK()) {
  345. #ifdef ENABLE_DUMP_IR
  346. mindspore::RDR::TriggerAll();
  347. #endif
  348. // When temporary variable 'op_context' has beed set failed status, the main thread need wait other threads until
  349. // they finish respective task, otherwise segmentation fault will happen when these task access 'op_context',
  350. // because it has been destroyed.
  351. std::mutex mutex;
  352. std::unique_lock<std::mutex> locker(mutex);
  353. std::condition_variable thread_blocker;
  354. const int64_t kTimeToWait = 2;
  355. thread_blocker.wait_for(locker, std::chrono::seconds(kTimeToWait));
  356. MS_LOG(EXCEPTION) << op_context.error_info_;
  357. }
  358. // Sync device stream.
  359. if (strategy == GraphExecutionStrategy::kPipeline) {
  360. std::set<DeviceContext *> sync_stream_device_contexts;
  361. for (auto &device_context : device_contexts) {
  362. MS_EXCEPTION_IF_NULL(device_context);
  363. if ((sync_stream_device_contexts.count(device_context) == 0) && (!device_context->SyncStream())) {
  364. MS_LOG(EXCEPTION) << "Sync stream failed:" << device_context->device_context_key().ToString();
  365. }
  366. (void)sync_stream_device_contexts.insert(device_context);
  367. }
  368. }
  369. double end_time = GetTime();
  370. const size_t kSecondsToMilliseconds = 1000;
  371. SetActorExecutionStrategy(actor_set, strategy, (end_time - start_time) * kSecondsToMilliseconds);
  372. }
  373. void GraphScheduler::SetActorExecutionStrategy(ActorSet *const actor_set, GraphExecutionStrategy strategy,
  374. double execution_time) {
  375. MS_EXCEPTION_IF_NULL(actor_set);
  376. MS_EXCEPTION_IF_NULL(actor_set->loop_count_actor_);
  377. ++actor_set->execution_count_;
  378. MS_LOG(DEBUG) << "Execution count: " << actor_set->execution_count_ << ", execution time cost: " << execution_time
  379. << " ms in multi thread or not: " << actor_set->is_multi_thread_execution_ << ".";
  380. #if defined(_WIN32) || defined(_WIN64)
  381. return;
  382. #endif
  383. // The step mode uses the default multi thread.
  384. if (strategy == GraphExecutionStrategy::kStep) {
  385. return;
  386. }
  387. if ((actor_set->copy_actors_.size() > 0) || (actor_set->super_kernel_actors_.size() > 0) ||
  388. (actor_set->kernel_actors_.size() > ActorDispatcher::kSingleThreadExecutionActorMaxNum) ||
  389. (actor_set->loop_count_actor_->loop_count() > 1)) {
  390. return;
  391. }
  392. if ((actor_set->is_multi_thread_execution_) &&
  393. (actor_set->execution_count_ >= ActorDispatcher::kMultiThreadExecutionCountBegin) &&
  394. (actor_set->execution_count_ <= ActorDispatcher::kMultiThreadExecutionCountEnd)) {
  395. actor_set->multi_thread_execution_time_ += execution_time;
  396. if (actor_set->execution_count_ == ActorDispatcher::kMultiThreadExecutionCountEnd) {
  397. actor_set->multi_thread_execution_time_ /=
  398. (ActorDispatcher::kMultiThreadExecutionCountEnd - ActorDispatcher::kMultiThreadExecutionCountBegin + 1);
  399. actor_set->is_multi_thread_execution_ = false;
  400. }
  401. return;
  402. }
  403. if ((!actor_set->is_multi_thread_execution_) &&
  404. (actor_set->execution_count_ >= ActorDispatcher::kSingleThreadExecutionCountBegin) &&
  405. (actor_set->execution_count_ <= ActorDispatcher::kSingleThreadExecutionCountEnd)) {
  406. actor_set->single_thread_execution_time_ += execution_time;
  407. if (actor_set->execution_count_ == ActorDispatcher::kSingleThreadExecutionCountEnd) {
  408. actor_set->single_thread_execution_time_ /=
  409. (ActorDispatcher::kSingleThreadExecutionCountEnd - ActorDispatcher::kSingleThreadExecutionCountBegin + 1);
  410. actor_set->is_multi_thread_execution_ =
  411. (actor_set->multi_thread_execution_time_ <= actor_set->single_thread_execution_time_) ? true : false;
  412. MS_LOG(INFO) << "Multi thread execution time cost: " << actor_set->multi_thread_execution_time_
  413. << " ms, single thread execution time cost: " << actor_set->single_thread_execution_time_
  414. << " ms, decide to use multi thread execution or not: " << actor_set->is_multi_thread_execution_
  415. << ".";
  416. }
  417. return;
  418. }
  419. }
  420. ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
  421. auto iter = actors_.find(actor_info);
  422. if (iter != actors_.end()) {
  423. return iter->second.get();
  424. } else {
  425. MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
  426. return nullptr;
  427. }
  428. }
  429. ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
  430. auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
  431. MS_EXCEPTION_IF_NULL(actor_set);
  432. (void)actors_.emplace(actor_set->name_, actor_set);
  433. auto host_queue = std::make_shared<HostTensorQueue>();
  434. actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
  435. actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
  436. actor_set->super_kernel_actors_ = BuildSuperKernelActor(graph_compiler_info);
  437. actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
  438. actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
  439. actor_set->data_prepare_actor_ =
  440. BuildDataPrepareActor(graph_compiler_info, actor_set->data_source_actors_, host_queue);
  441. actor_set->control_actors_ = control_node_scheduler_.Build(graph_compiler_info);
  442. return actor_set;
  443. }
  444. void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
  445. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
  446. return;
  447. }
  448. for (const auto &graph : graph_compiler_info.graphs_) {
  449. MS_EXCEPTION_IF_NULL(graph);
  450. auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
  451. for (const auto &output_with_index : outputs) {
  452. auto output_kernel = output_with_index.first;
  453. MS_EXCEPTION_IF_NULL(output_kernel);
  454. auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
  455. if (origin_output_with_index.first == nullptr) {
  456. MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
  457. << " with index: " << output_with_index.second << " has no front node.";
  458. continue;
  459. }
  460. auto kernel_type = FetchKernelTransformType(output_kernel, graph, graph_compiler_info.origin_parameters_order_);
  461. auto output_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_kernel, graph);
  462. if (output_actor == nullptr) {
  463. MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
  464. << " with index:" << output_with_index.second
  465. << " is not actor, and the kernel type is:" << kernel_type;
  466. }
  467. auto output_actor_name = (output_actor != nullptr) ? output_actor->GetAID().Name() : "";
  468. (void)graph_output_to_actor_.emplace(origin_output_with_index, GraphOutputPair(output_actor, output_with_index));
  469. MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
  470. << " with index:" << output_with_index.second << " to actor:" << output_actor_name
  471. << ", from front node:" << origin_output_with_index.first->fullname_with_scope()
  472. << " with index:" << origin_output_with_index.second;
  473. }
  474. }
  475. }
  476. void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
  477. MS_EXCEPTION_IF_NULL(actor_set);
  478. std::vector<AbstractActor *> auto_monad_actors;
  479. std::vector<CNodePtr> communication_nodes;
  480. for (const auto &graph : graph_compiler_info.graphs_) {
  481. MS_EXCEPTION_IF_NULL(graph);
  482. if (graph->execution_order().empty()) {
  483. MS_LOG(INFO) << "The graph " << graph->graph_id() << " is an empty graph and skips linking.";
  484. continue;
  485. }
  486. if (graph->is_executing_sink()) {
  487. LinkDataArrowInSinkMode(graph, graph_compiler_info, &auto_monad_actors);
  488. } else {
  489. LinkDataArrowInNonSinkMode(graph, graph_compiler_info, &auto_monad_actors, &communication_nodes);
  490. }
  491. }
  492. LinkGlobalControlArrow(actor_set, communication_nodes, auto_monad_actors, graph_compiler_info);
  493. LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
  494. // Link the arrow in the control flow scene.
  495. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline &&
  496. graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited()) {
  497. control_node_scheduler_.Link(actor_set, graph_compiler_info);
  498. }
  499. }
  500. std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
  501. const HostTensorQueuePtr &host_queue) {
  502. std::vector<DataSourceActorPtr> data_source_actors;
  503. HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  504. size_t data_node_position = 0;
  505. mindspore::HashMap<AnfNodePtr, size_t> front_node_position_temp_map;
  506. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  507. const auto &graph = graph_compiler_info.graphs_[i];
  508. const auto &device_context = graph_compiler_info.device_contexts_[i];
  509. MS_EXCEPTION_IF_NULL(graph);
  510. // Build host queue data source actor.
  511. const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
  512. for (size_t j = 0; j < input_nodes.size(); j++) {
  513. const auto &input_node = input_nodes[j];
  514. MS_EXCEPTION_IF_NULL(input_node);
  515. if (IsHostQueueDSActor(input_node, graph, graph_compiler_info.origin_parameters_order_,
  516. graph_compiler_info.strategy_)) {
  517. if (host_queue_ds_actor == nullptr) {
  518. auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
  519. MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
  520. host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr,
  521. nullptr, host_queue);
  522. InsertActor(host_queue_ds_actor.get());
  523. (void)data_source_actors.emplace_back(host_queue_ds_actor);
  524. }
  525. const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
  526. // In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
  527. // is saved in the host queue data source actor.
  528. if (front_node_position_temp_map.count(front_node) > 0) {
  529. (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node,
  530. front_node_position_temp_map[front_node]);
  531. continue;
  532. }
  533. (void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
  534. (void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
  535. (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
  536. // In control flow, need to rely on the front node to find the location of the corresponding real parameter.
  537. (void)host_queue_ds_actor->data_node_position_map_.emplace(front_node, data_node_position);
  538. (void)front_node_position_temp_map.emplace(front_node, data_node_position);
  539. data_node_position++;
  540. }
  541. }
  542. // The graph sink mode has no device queue data source actor.
  543. if (!graph->is_executing_sink()) {
  544. // Build device queue data source actor.
  545. const auto &execution_order = graph->execution_order();
  546. const auto &iter =
  547. std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
  548. return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
  549. });
  550. if (iter != execution_order.end()) {
  551. auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
  552. MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
  553. auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
  554. actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
  555. MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
  556. InsertActor(device_queue_ds_actor.get());
  557. (void)data_source_actors.emplace_back(device_queue_ds_actor);
  558. device_queue_ds_actor->data_kernel_ = *iter;
  559. device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
  560. }
  561. }
  562. }
  563. MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  564. const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
  565. // Initialize the parameter in the control node, first get all the front parameters in the control node, then find
  566. // the corresponding backend parameter from the map, and insert it into the host data source actor
  567. const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
  568. for (const auto &parameter : control_node_parameters) {
  569. if (IsPersistentDeviceTensor(parameter)) {
  570. continue;
  571. }
  572. auto backend_iter = front_to_backend_parameter.find({parameter, 0});
  573. if (backend_iter == front_to_backend_parameter.end() || backend_iter->second.empty()) {
  574. MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
  575. }
  576. if (host_queue_ds_actor == nullptr) {
  577. auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
  578. MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
  579. host_queue_ds_actor =
  580. std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr, nullptr, host_queue);
  581. InsertActor(host_queue_ds_actor.get());
  582. (void)data_source_actors.emplace_back(host_queue_ds_actor);
  583. }
  584. if (host_queue_ds_actor->data_node_position_map_.find(parameter) !=
  585. host_queue_ds_actor->data_node_position_map_.end()) {
  586. continue;
  587. }
  588. const auto &backend_node = backend_iter->second.begin()->first;
  589. auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
  590. if (iter != host_queue_ds_actor->data_nodes_.end()) {
  591. (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter,
  592. iter - host_queue_ds_actor->data_nodes_.begin());
  593. } else {
  594. (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
  595. (void)host_queue_ds_actor->data_node_position_map_.emplace(backend_iter->second.begin()->first,
  596. host_queue_ds_actor->data_nodes_.size());
  597. (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.begin()->first);
  598. (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.begin()->second);
  599. }
  600. }
  601. return data_source_actors;
  602. }
  603. std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  604. std::vector<KernelActorPtr> kernel_actors;
  605. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  606. const auto &graph = graph_compiler_info.graphs_[i];
  607. const auto &device_context = graph_compiler_info.device_contexts_[i];
  608. MS_EXCEPTION_IF_NULL(graph);
  609. if (graph->is_executing_sink()) {
  610. continue;
  611. }
  612. auto execution_order = graph->execution_order();
  613. // Single op graph in step mode, kernel actor executes synchronously.
  614. bool is_single_op_graph = execution_order.size() == 1;
  615. GraphExecutionStrategy strategy = graph_compiler_info.strategy_;
  616. if (strategy == GraphExecutionStrategy::kStep) {
  617. strategy = (is_single_op_graph ? strategy : GraphExecutionStrategy::kPipeline);
  618. }
  619. for (auto &kernel : execution_order) {
  620. MS_EXCEPTION_IF_NULL(kernel);
  621. if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
  622. auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context,
  623. memory_manager_aid_, debug_aid_, recorder_aid_, strategy);
  624. MS_EXCEPTION_IF_NULL(kernel_actor);
  625. InsertActor(kernel_actor.get());
  626. (void)kernel_actors.emplace_back(kernel_actor);
  627. }
  628. }
  629. }
  630. return kernel_actors;
  631. }
  632. std::vector<SuperKernelActorPtr> GraphScheduler::BuildSuperKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  633. std::vector<SuperKernelActorPtr> super_kernel_actors;
  634. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  635. const auto &graph = graph_compiler_info.graphs_[i];
  636. const auto &device_context = graph_compiler_info.device_contexts_[i];
  637. MS_EXCEPTION_IF_NULL(graph);
  638. if (!graph->is_executing_sink()) {
  639. continue;
  640. }
  641. if (graph->execution_order().empty()) {
  642. MS_LOG(INFO) << "The graph " << graph->graph_id() << " is an empty graph and skips building.";
  643. continue;
  644. }
  645. auto actor_name = graph->ToString() + "_SuperKernelActor";
  646. auto super_kernel_actor =
  647. std::make_shared<SuperKernelActor>(actor_name, graph, device_context, memory_manager_aid_, nullptr, nullptr);
  648. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  649. InsertActor(super_kernel_actor.get());
  650. (void)super_kernel_actors.emplace_back(super_kernel_actor);
  651. }
  652. return super_kernel_actors;
  653. }
  654. LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
  655. auto actor_set = Fetch(graph_compiler_info.name_);
  656. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  657. return nullptr;
  658. }
  659. auto loop_count = ConfigManager::GetInstance().iter_num();
  660. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) ||
  661. (graph_compiler_info.graphs_.size() == 1 && graph_compiler_info.graphs_[0]->is_loop_count_sink())) {
  662. loop_count = 1;
  663. }
  664. auto actor_name = graph_compiler_info.name_ + "_LoopCountActor";
  665. auto loop_count_actor =
  666. std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_);
  667. MS_LOG(INFO) << "Create loop count actor: " << actor_name;
  668. MS_EXCEPTION_IF_NULL(loop_count_actor);
  669. InsertActor(loop_count_actor.get());
  670. return loop_count_actor;
  671. }
  672. OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_compiler_info) {
  673. auto actor_set = Fetch(graph_compiler_info.name_);
  674. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) && IsSingleOpActorSet(actor_set)) {
  675. return nullptr;
  676. }
  677. auto loop_count = ConfigManager::GetInstance().iter_num();
  678. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) ||
  679. (graph_compiler_info.graphs_.size() == 1 && graph_compiler_info.graphs_[0]->is_loop_count_sink())) {
  680. loop_count = 1;
  681. }
  682. auto actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
  683. auto output_actor = std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.outputs_num_);
  684. MS_LOG(INFO) << "Create output actor: " << actor_name;
  685. MS_EXCEPTION_IF_NULL(output_actor);
  686. InsertActor(output_actor.get());
  687. return output_actor;
  688. }
  689. DataPrepareActorPtr GraphScheduler::BuildDataPrepareActor(const GraphCompilerInfo &graph_compiler_info,
  690. const std::vector<DataSourceActorPtr> &data_source_actors,
  691. const HostTensorQueuePtr &host_queue) {
  692. HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  693. auto iter = std::find_if(data_source_actors.begin(), data_source_actors.end(), [&](const auto &data_source_actor) {
  694. return data_source_actor->type_ == KernelTransformType::kHostDataSourceActor;
  695. });
  696. if (iter != data_source_actors.end()) {
  697. host_queue_ds_actor = std::dynamic_pointer_cast<HostQueueDataSourceActor>(*iter);
  698. }
  699. auto actor_name = graph_compiler_info.name_ + "_DataPrepareActor";
  700. auto data_prepare_actor = std::make_shared<DataPrepareActor>(actor_name, memory_manager_aid_, debug_aid_,
  701. &graph_compiler_info, host_queue_ds_actor, host_queue);
  702. MS_LOG(INFO) << "Create data prepare actor: " << actor_name;
  703. MS_EXCEPTION_IF_NULL(data_prepare_actor);
  704. // Cache the nodes which need continuous memory.
  705. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
  706. for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
  707. const auto &graph = graph_compiler_info.graphs_[index];
  708. MS_EXCEPTION_IF_NULL(graph);
  709. if (graph->is_executing_sink()) {
  710. continue;
  711. }
  712. auto &execution_order = graph->execution_order();
  713. for (auto &kernel : execution_order) {
  714. if (!AnfAlgo::IsCommunicationOp(kernel)) {
  715. continue;
  716. }
  717. auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]);
  718. auto value = std::make_pair(false, false);
  719. if (AnfAlgo::GetInputTensorNum(kernel) > 1) {
  720. value.first = true;
  721. }
  722. if (AnfAlgo::GetOutputTensorNum(kernel) > 1) {
  723. value.second = true;
  724. }
  725. if ((value.first == true) || (value.second == true)) {
  726. data_prepare_actor->continuous_memory_nodes_[key] = value;
  727. }
  728. }
  729. }
  730. }
  731. InsertActor(data_prepare_actor.get());
  732. return data_prepare_actor;
  733. }
  734. std::vector<AbstractActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
  735. GraphExecutionStrategy strategy) {
  736. MS_EXCEPTION_IF_NULL(actor_set);
  737. std::vector<AbstractActorPtr> no_input_kernel_actors;
  738. for (auto &super_kernel_actor : actor_set->super_kernel_actors_) {
  739. MS_EXCEPTION_IF_NULL(super_kernel_actor);
  740. if ((super_kernel_actor->input_datas_num_ == 0) && (super_kernel_actor->input_controls_num_ == 0)) {
  741. (void)no_input_kernel_actors.emplace_back(super_kernel_actor);
  742. }
  743. }
  744. for (auto &kernel_actor : actor_set->kernel_actors_) {
  745. MS_EXCEPTION_IF_NULL(kernel_actor);
  746. // Framework will trigger kernel actor running in the step execution strategy.
  747. if (strategy == GraphExecutionStrategy::kStep && IsSingleOpActorSet(actor_set)) {
  748. kernel_actor->input_controls_num_++;
  749. continue;
  750. }
  751. if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
  752. (void)no_input_kernel_actors.emplace_back(kernel_actor);
  753. }
  754. }
  755. return no_input_kernel_actors;
  756. }
  757. void GraphScheduler::LinkDataArrowInSinkMode(const KernelGraphPtr &graph, const GraphCompilerInfo &graph_compiler_info,
  758. std::vector<AbstractActor *> *const auto_monad_actors) {
  759. MS_EXCEPTION_IF_NULL(graph);
  760. // The data arrow linking is taken over by the control flow.
  761. if (graph_compiler_info.control_node_parser_ != nullptr &&
  762. graph_compiler_info.control_node_parser_->IsControlFlowDataArrow(graph, nullptr)) {
  763. return;
  764. }
  765. auto to_actor_name = graph->ToString() + "_SuperKernelActor";
  766. auto to_actor = dynamic_cast<SuperKernelActor *>(FetchActor(to_actor_name));
  767. MS_EXCEPTION_IF_NULL(to_actor);
  768. auto &input_nodes = graph->input_nodes();
  769. for (size_t node_index = 0; node_index < input_nodes.size(); ++node_index) {
  770. auto &input_node = input_nodes[node_index];
  771. MS_EXCEPTION_IF_NULL(input_node);
  772. if (HasAbstractMonad(input_node)) {
  773. MS_LOG(INFO) << "The graph:" << graph->graph_id()
  774. << " has abstract monad input node:" << input_node->DebugString() << ", input index:" << node_index;
  775. LinkControlArrowByAutoMonad(to_actor, input_node, graph);
  776. (void)auto_monad_actors->emplace_back(to_actor);
  777. continue; // No data arrow for monad input.
  778. }
  779. UpdateRefCount(input_node, 0, true);
  780. KernelWithIndex from_kernel_with_output_idx = std::make_pair(input_node, 0);
  781. KernelWithIndex to_kernel_with_input_idx = std::make_pair(input_node, node_index);
  782. // The gather of linking data arrows of kernel by the different from kernel type.
  783. LinkDataArrow(to_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
  784. }
  785. }
  786. void GraphScheduler::LinkDataArrowInNonSinkMode(const KernelGraphPtr &graph,
  787. const GraphCompilerInfo &graph_compiler_info,
  788. std::vector<AbstractActor *> *const auto_monad_actors,
  789. std::vector<CNodePtr> *const communication_nodes) {
  790. MS_EXCEPTION_IF_NULL(graph);
  791. MS_EXCEPTION_IF_NULL(auto_monad_actors);
  792. MS_EXCEPTION_IF_NULL(communication_nodes);
  793. const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
  794. prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};
  795. auto &execution_order = graph->execution_order();
  796. // Foreach the execution order to link the actors.
  797. for (const auto &kernel : execution_order) {
  798. MS_EXCEPTION_IF_NULL(kernel);
  799. if (AnfAlgo::IsCommunicationOp(kernel)) {
  800. (void)communication_nodes->emplace_back(kernel);
  801. }
  802. if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
  803. continue;
  804. }
  805. const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
  806. MS_EXCEPTION_IF_NULL(kernel_actor);
  807. for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
  808. auto input_node = AnfAlgo::GetInputNode(kernel, i);
  809. // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
  810. if (IsOneOfPrimitiveCNode(input_node, auto_monad_prims) || HasAbstractMonad(input_node)) {
  811. LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
  812. }
  813. if (HasAbstractMonad(input_node)) {
  814. (void)auto_monad_actors->emplace_back(kernel_actor);
  815. continue; // No data arrow for monad input.
  816. }
  817. KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
  818. KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
  819. // The data arrow linking is taken over by the control flow.
  820. if (graph_compiler_info.control_node_parser_ != nullptr &&
  821. graph_compiler_info.control_node_parser_->IsControlFlowDataArrow(graph, from_kernel_with_output_idx.first)) {
  822. continue;
  823. }
  824. // The gather of linking data arrows of kernel by the different from kernel type.
  825. LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
  826. }
  827. }
  828. // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
  829. LinkControlArrowBySendRecvNodes(graph);
  830. }
  831. void GraphScheduler::LinkDataArrow(AbstractActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
  832. const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
  833. const KernelWithIndex &to_kernel_with_input_idx) {
  834. MS_EXCEPTION_IF_NULL(to_actor);
  835. MS_EXCEPTION_IF_NULL(graph);
  836. auto from_kernel = from_kernel_with_output_idx.first;
  837. MS_EXCEPTION_IF_NULL(from_kernel);
  838. auto kernel_type = FetchKernelTransformType(from_kernel, graph, graph_compiler_info.origin_parameters_order_,
  839. graph_compiler_info.strategy_);
  840. auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, from_kernel, graph);
  841. if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
  842. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
  843. MS_LOG(WARNING) << "Invalid from node:" << from_kernel->fullname_with_scope() << ", type:" << kernel_type;
  844. }
  845. return;
  846. }
  847. (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
  848. to_kernel_with_input_idx, graph);
  849. }
  850. void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, AbstractActor *const to_actor,
  851. const KernelWithIndex &from_kernel_with_output_idx,
  852. const KernelWithIndex &to_kernel_with_input_idx,
  853. const KernelGraphPtr &graph) {
  854. MS_EXCEPTION_IF_NULL(to_actor);
  855. MS_EXCEPTION_IF_NULL(graph);
  856. if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
  857. return;
  858. }
  859. auto from_kernel = from_kernel_with_output_idx.first;
  860. MS_EXCEPTION_IF_NULL(from_kernel);
  861. auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
  862. (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
  863. }
  864. void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, AbstractActor *to_actor,
  865. const KernelWithIndex &from_kernel_with_output_idx,
  866. const KernelWithIndex &to_kernel_with_input_idx,
  867. const KernelGraphPtr &graph) {
  868. MS_EXCEPTION_IF_NULL(to_actor);
  869. MS_EXCEPTION_IF_NULL(graph);
  870. auto internal_parameter = from_kernel_with_output_idx.first;
  871. MS_EXCEPTION_IF_NULL(internal_parameter);
  872. // Parameter ---> front node.
  873. auto front_output_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
  874. auto front_output_node = front_output_with_index.first;
  875. MS_EXCEPTION_IF_NULL(front_output_node);
  876. if (IsSwitchActor(front_output_node)) {
  877. return;
  878. }
  879. auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  880. AbstractActor *real_from_actor = nullptr;
  881. KernelTransformType kernel_type;
  882. if (IsPersistentDeviceTensor(front_output_node)) {
  883. kernel_type = KernelTransformType::kDeviceTensorStore;
  884. } else {
  885. // front node ---> actor.
  886. if (graph_output_to_actor_.count(front_output_with_index) == 0) {
  887. MS_LOG(EXCEPTION) << "Can't find actor by front node:" << AnfAlgo::GetNodeDebugString(front_output_node)
  888. << ", internal parameter:" << AnfAlgo::GetNodeDebugString(internal_parameter);
  889. }
  890. auto actor_pair = graph_output_to_actor_[front_output_with_index];
  891. MS_EXCEPTION_IF_NULL(actor_pair.first);
  892. MS_EXCEPTION_IF_NULL(actor_pair.second.first);
  893. MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
  894. << ", corresponding front node:" << front_output_node->fullname_with_scope()
  895. << " with index:" << front_output_with_index.second
  896. << ", from actor:" << actor_pair.first->GetAID().Name()
  897. << " node:" << actor_pair.second.first->fullname_with_scope()
  898. << " with index:" << actor_pair.second.second << ", to actor:" << to_actor->GetAID().Name()
  899. << " with index:" << to_kernel_with_input_idx.second;
  900. real_from_actor = actor_pair.first;
  901. real_from_kernel_with_output_idx = actor_pair.second;
  902. kernel_type = actor_pair.first->type_;
  903. }
  904. if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
  905. MS_LOG(EXCEPTION) << "Invalid internal parameter:" << internal_parameter->DebugString() << ", type:" << kernel_type;
  906. }
  907. (this->*kKernelTypeToLinkFunc[kernel_type])(real_from_actor, to_actor, real_from_kernel_with_output_idx,
  908. to_kernel_with_input_idx, graph);
  909. }
  910. void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  911. const KernelWithIndex &from_kernel_with_output_idx,
  912. const KernelWithIndex &to_kernel_with_input_idx,
  913. const KernelGraphPtr &) {
  914. MS_EXCEPTION_IF_NULL(from_actor);
  915. MS_EXCEPTION_IF_NULL(to_actor);
  916. auto from_kernel = from_kernel_with_output_idx.first;
  917. MS_EXCEPTION_IF_NULL(from_kernel);
  918. auto from_output_index = from_kernel_with_output_idx.second;
  919. auto to_input_index = to_kernel_with_input_idx.second;
  920. // Get the position of from kernel in the data source actor.
  921. auto position = from_actor->FetchNodePosition(from_kernel);
  922. if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
  923. MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
  924. }
  925. if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
  926. LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
  927. } else {
  928. AddDataArrow(from_actor, to_actor, from_kernel, from_output_index, to_input_index);
  929. }
  930. }
  931. void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  932. const KernelWithIndex &from_kernel_with_output_idx,
  933. const KernelWithIndex &to_kernel_with_input_idx,
  934. const KernelGraphPtr &graph) {
  935. auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
  936. MS_EXCEPTION_IF_NULL(host_ds_actor);
  937. MS_EXCEPTION_IF_NULL(from_kernel_with_output_idx.first);
  938. KernelWithIndex real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  939. // Get the position and real kernel by from kernel in the data source actor.
  940. auto position = host_ds_actor->FetchNodePosition(from_kernel_with_output_idx.first);
  941. real_from_kernel_with_output_idx.first = host_ds_actor->FetchNode(position);
  942. LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx, graph);
  943. }
  944. void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  945. const KernelWithIndex &from_kernel_with_output_idx,
  946. const KernelWithIndex &to_kernel_with_input_idx,
  947. const KernelGraphPtr &graph) {
  948. auto real_from_actor = from_actor;
  949. auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  950. auto from_kernel = from_kernel_with_output_idx.first;
  951. // Update the from kernel info by the real node info.
  952. MS_EXCEPTION_IF_NULL(from_kernel);
  953. if (IsSkippedKernelActor(from_kernel)) {
  954. real_from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(from_kernel, 0, false);
  955. MS_EXCEPTION_IF_NULL(real_from_kernel_with_output_idx.first);
  956. LinkControlArrowBySkippedNode(to_actor, from_kernel);
  957. MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first);
  958. MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
  959. << to_kernel_with_input_idx.first->fullname_with_scope()
  960. << ", aggregate input index: " << to_kernel_with_input_idx.second
  961. << ", skip node: " << from_kernel->fullname_with_scope()
  962. << ", real node: " << real_from_kernel_with_output_idx.first->fullname_with_scope();
  963. real_from_actor = FetchActor(real_from_kernel_with_output_idx.first->fullname_with_scope());
  964. MS_EXCEPTION_IF_NULL(real_from_actor);
  965. }
  966. LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx,
  967. graph);
  968. }
  969. void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, AbstractActor *const to_actor,
  970. const KernelWithIndex &from_kernel_with_output_idx,
  971. const KernelWithIndex &to_kernel_with_input_idx) {
  972. MS_EXCEPTION_IF_NULL(from_actor);
  973. MS_EXCEPTION_IF_NULL(to_actor);
  974. auto from_kernel = from_kernel_with_output_idx.first;
  975. MS_EXCEPTION_IF_NULL(from_kernel);
  976. std::string name = "copy_from:" + from_actor->GetAID().Name() + "_node:" + from_kernel->fullname_with_scope() +
  977. "_output_index:" + std::to_string(from_kernel_with_output_idx.second);
  978. CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
  979. // Link between from actor and copy actor.
  980. if (copy_actor == nullptr) {
  981. // Create the copy actor.
  982. auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
  983. (void)copy_actors_.emplace_back(copy_actor_shared_ptr);
  984. copy_actor = copy_actor_shared_ptr.get();
  985. MS_EXCEPTION_IF_NULL(copy_actor);
  986. InsertActor(copy_actor);
  987. // Set the member device_contexts_ of the copy actor.
  988. auto position = from_actor->FetchNodePosition(from_kernel);
  989. if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
  990. MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
  991. }
  992. auto from_device_context = from_actor->device_contexts_[position];
  993. auto to_device_context = to_actor->device_contexts_[0];
  994. MS_EXCEPTION_IF_NULL(from_device_context);
  995. MS_EXCEPTION_IF_NULL(to_device_context);
  996. (void)copy_actor->device_contexts_.emplace_back(from_device_context);
  997. (void)copy_actor->device_contexts_.emplace_back(to_device_context);
  998. // Set the member output_ of the copy actor.
  999. if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
  1000. copy_actor->output_ = AnfAlgo::GetMutableOutputAddr(to_kernel_with_input_idx.first, 0, false);
  1001. } else {
  1002. copy_actor->output_ =
  1003. AnfAlgo::GetPrevNodeMutableOutputAddr(to_kernel_with_input_idx.first, to_kernel_with_input_idx.second, false);
  1004. }
  1005. MS_EXCEPTION_IF_NULL(copy_actor->output_);
  1006. if (copy_actor->output_->DeviceType() != to_device_context->GetDeviceAddressType()) {
  1007. MS_LOG(EXCEPTION) << "The device type is not equal, output device type:" << copy_actor->output_->DeviceType()
  1008. << ", to device context type:" << to_device_context->GetDeviceAddressType();
  1009. }
  1010. // Link between from actor and copy actor.
  1011. AddDataArrow(from_actor, copy_actor, from_kernel, from_kernel_with_output_idx.second, 0);
  1012. }
  1013. // If the copy actor already exists, only need link between copy actor and to actor.
  1014. AddDataArrow(copy_actor, to_actor, nullptr, 0, to_kernel_with_input_idx.second);
  1015. if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
  1016. UpdateRefCount(copy_actor->output_.get(), true);
  1017. } else {
  1018. UpdateRefCount(copy_actor->output_.get(), false);
  1019. }
  1020. }
  1021. void GraphScheduler::LinkControlArrowByAutoMonad(AbstractActor *to_actor, const AnfNodePtr &from_node,
  1022. const KernelGraphPtr &graph) {
  1023. MS_EXCEPTION_IF_NULL(to_actor);
  1024. MS_EXCEPTION_IF_NULL(from_node);
  1025. MS_EXCEPTION_IF_NULL(graph);
  1026. // Find the real input node, include the monad node and make tuple node.
  1027. const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
  1028. prim::kPrimMakeTuple};
  1029. const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
  1030. MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
  1031. auto input_anfnode = input_kernel_with_output_idx.first;
  1032. CNodePtr input_cnode = nullptr;
  1033. if (input_anfnode->isa<CNode>()) {
  1034. input_cnode = input_anfnode->cast<CNodePtr>();
  1035. }
  1036. // Make tuple node needs to be expanded.
  1037. if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
  1038. MS_EXCEPTION_IF_NULL(input_cnode);
  1039. for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
  1040. LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph);
  1041. }
  1042. return;
  1043. }
  1044. const mindspore::HashSet<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
  1045. prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
  1046. // Get the real depend input by monad node which needs to link the control arrow.
  1047. std::vector<AnfNodePtr> real_depend_inputs;
  1048. if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimDepend) ||
  1049. AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimLoad)) {
  1050. MS_EXCEPTION_IF_NULL(input_cnode);
  1051. real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
  1052. // The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input
  1053. // node in this scene.
  1054. if (IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
  1055. real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
  1056. }
  1057. } else if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
  1058. MS_EXCEPTION_IF_NULL(input_cnode);
  1059. for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
  1060. real_depend_inputs.push_back(input_cnode->input(i));
  1061. }
  1062. } else {
  1063. real_depend_inputs.push_back(input_anfnode);
  1064. }
  1065. for (const auto &real_depend_input : real_depend_inputs) {
  1066. auto real_depend_input_with_idx = AnfAlgo::VisitKernelWithReturnType(real_depend_input, 0, false, return_types);
  1067. MS_EXCEPTION_IF_NULL(real_depend_input_with_idx.first);
  1068. auto real_depend_kernel = real_depend_input_with_idx.first;
  1069. // Update the real depend kernel in the subgraphs connecting scene.
  1070. if (IsInternalParameter(real_depend_kernel, graph)) {
  1071. auto front_output_with_index = graph->GetFrontNodeByInternalParameter(real_depend_kernel);
  1072. MS_EXCEPTION_IF_NULL(front_output_with_index.first);
  1073. if (graph_output_to_actor_.count(front_output_with_index) == 0) {
  1074. if (AnfAlgo::IsCallNode(front_output_with_index.first)) {
  1075. continue;
  1076. }
  1077. MS_LOG(EXCEPTION) << "Can't find graph output by front node:" << front_output_with_index.first->DebugString();
  1078. }
  1079. real_depend_kernel = graph_output_to_actor_[front_output_with_index].second.first;
  1080. MS_EXCEPTION_IF_NULL(real_depend_kernel);
  1081. MS_LOG(INFO) << "The graph " << graph->graph_id() << " link control arrow by auto monad from internal parameter: "
  1082. << real_depend_input_with_idx.first->DebugString()
  1083. << ", front output node: " << front_output_with_index.first->fullname_with_scope()
  1084. << ", backend output node: " << real_depend_kernel->fullname_with_scope();
  1085. auto from_actor = graph_output_to_actor_[front_output_with_index].first;
  1086. if (from_actor != nullptr) {
  1087. MS_LOG(INFO) << "Link control arrow by auto monad from actor: " << from_actor->GetAID().Name()
  1088. << ", to actor: " << to_actor->GetAID().Name() << " for the graph: " << graph->graph_id();
  1089. AddControlArrow(from_actor, to_actor);
  1090. continue;
  1091. }
  1092. }
  1093. // The monad node and make tuple node need recursion.
  1094. if (IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
  1095. LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph);
  1096. continue;
  1097. }
  1098. auto type = FetchKernelTransformType(real_depend_kernel, nullptr);
  1099. auto from_actor = FetchActor(type, "", real_depend_kernel);
  1100. if (from_actor == nullptr) {
  1101. MS_LOG(DEBUG) << "Link control arrow by auto monad from depend node:" << real_depend_kernel->fullname_with_scope()
  1102. << " is not actor for the graph: " << graph->graph_id();
  1103. continue;
  1104. }
  1105. MS_LOG(INFO) << "Link control arrow by auto monad from actor: " << from_actor->GetAID().Name()
  1106. << ", to actor: " << to_actor->GetAID().Name() << " for the graph: " << graph->graph_id();
  1107. AddControlArrow(from_actor, to_actor);
  1108. }
  1109. }
  1110. void GraphScheduler::LinkControlArrowBySkippedNode(AbstractActor *to_actor, const AnfNodePtr &skipped_node) {
  1111. MS_EXCEPTION_IF_NULL(to_actor);
  1112. MS_EXCEPTION_IF_NULL(skipped_node);
  1113. // Link the control arrow from all the inputs of skipped node to the user of skipped node.
  1114. auto input_num = AnfAlgo::GetInputTensorNum(skipped_node);
  1115. for (size_t i = 0; i < input_num; ++i) {
  1116. auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
  1117. MS_EXCEPTION_IF_NULL(kernel_with_index.first);
  1118. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(kernel_with_index.first->fullname_with_scope()));
  1119. MS_EXCEPTION_IF_NULL(from_actor);
  1120. MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
  1121. << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
  1122. AddControlArrow(from_actor, to_actor);
  1123. }
  1124. }
  1125. void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
  1126. MS_EXCEPTION_IF_NULL(graph);
  1127. for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
  1128. auto to_allreduce_node = from_iter.first;
  1129. auto from_send_node = from_iter.second.first;
  1130. auto from_recv_node = from_iter.second.second;
  1131. MS_EXCEPTION_IF_NULL(to_allreduce_node);
  1132. MS_EXCEPTION_IF_NULL(from_send_node);
  1133. MS_EXCEPTION_IF_NULL(from_recv_node);
  1134. MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
  1135. auto to_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(to_allreduce_node->fullname_with_scope()));
  1136. auto from_send_actor = dynamic_cast<KernelActor *>(FetchActor(from_send_node->fullname_with_scope()));
  1137. auto from_recv_actor = dynamic_cast<KernelActor *>(FetchActor(from_recv_node->fullname_with_scope()));
  1138. MS_EXCEPTION_IF_NULL(to_allreduce_actor);
  1139. MS_EXCEPTION_IF_NULL(from_send_actor);
  1140. MS_EXCEPTION_IF_NULL(from_recv_actor);
  1141. // inputs of to_allreduce_actor --> from_send_actor
  1142. for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
  1143. auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
  1144. if (input_actor != nullptr) {
  1145. AddControlArrow(input_actor, from_send_actor);
  1146. }
  1147. }
  1148. // from_send_actor --> from_recv_actor
  1149. AddControlArrow(from_send_actor, from_recv_actor);
  1150. // from_recv_actor --> to_allreduce_actor
  1151. AddControlArrow(from_recv_actor, to_allreduce_actor);
  1152. }
  1153. for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
  1154. auto from_allreduce_node = to_iter.first;
  1155. auto to_send_node = to_iter.second.first;
  1156. auto to_recv_node = to_iter.second.second;
  1157. MS_EXCEPTION_IF_NULL(from_allreduce_node);
  1158. MS_EXCEPTION_IF_NULL(to_send_node);
  1159. MS_EXCEPTION_IF_NULL(to_recv_node);
  1160. MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
  1161. auto from_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(from_allreduce_node->fullname_with_scope()));
  1162. auto to_send_actor = dynamic_cast<KernelActor *>(FetchActor(to_send_node->fullname_with_scope()));
  1163. auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
  1164. MS_EXCEPTION_IF_NULL(from_allreduce_actor);
  1165. MS_EXCEPTION_IF_NULL(to_send_actor);
  1166. MS_EXCEPTION_IF_NULL(to_recv_actor);
  1167. // from_allreduce_actor --> to_send_actor
  1168. AddControlArrow(from_allreduce_actor, to_send_actor);
  1169. // to_send_actor --> to_recv_actor
  1170. AddControlArrow(to_send_actor, to_recv_actor);
  1171. // to_recv_actor --> outputs of from_allreduce_actor
  1172. for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
  1173. auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
  1174. if (output_actor != nullptr) {
  1175. AddControlArrow(to_recv_actor, output_actor);
  1176. }
  1177. }
  1178. // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
  1179. // reused only when the recv node runs finished, which is expressed by the reference count increased.
  1180. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
  1181. auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
  1182. MS_EXCEPTION_IF_NULL(device_tensor);
  1183. UpdateRefCount(device_tensor.get());
  1184. (void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
  1185. }
  1186. }
  1187. }
  1188. void GraphScheduler::LinkGlobalControlArrow(ActorSet *const actor_set, const std::vector<CNodePtr> &communication_nodes,
  1189. const std::vector<AbstractActor *> &auto_monad_actors,
  1190. const GraphCompilerInfo &graph_compiler_info) {
  1191. MS_EXCEPTION_IF_NULL(actor_set);
  1192. // Link the control arrows by the communication nodes to ensure communication nodes running order.
  1193. LinkControlArrowByCommunicationNode(communication_nodes, graph_compiler_info);
  1194. // Auto monad actor may modify the device tensor store.
  1195. LinkDeviceTensorStoreForAutoMonadActor(auto_monad_actors);
  1196. // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
  1197. actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set, graph_compiler_info.strategy_);
  1198. // Link the control arrows of data prepare actor, which depends on the no input kernel actors.
  1199. if ((graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) || (!IsSingleOpActorSet(actor_set))) {
  1200. LinkControlArrowForDataPrepareActor(actor_set->data_prepare_actor_.get(), actor_set,
  1201. graph_compiler_info.control_node_parser_);
  1202. }
  1203. LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
  1204. graph_compiler_info.control_node_parser_);
  1205. }
  1206. void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  1207. const GraphCompilerInfo &graph_compiler_info) {
  1208. const size_t kCommunicationNodesMinNum = 2;
  1209. if (communication_nodes.size() < kCommunicationNodesMinNum) {
  1210. return;
  1211. }
  1212. // Ensure communication node to execute orderly.
  1213. for (size_t i = 1; i < communication_nodes.size(); ++i) {
  1214. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
  1215. auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
  1216. MS_EXCEPTION_IF_NULL(from_actor);
  1217. MS_EXCEPTION_IF_NULL(to_actor);
  1218. AddControlArrow(from_actor, to_actor);
  1219. }
  1220. // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
  1221. // Using the multi stream to optimize the performance in the future.
  1222. for (auto &graph : graph_compiler_info.graphs_) {
  1223. MS_EXCEPTION_IF_NULL(graph);
  1224. auto &execution_order = graph->execution_order();
  1225. for (size_t i = 1; i < execution_order.size(); ++i) {
  1226. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i - 1]->fullname_with_scope()));
  1227. auto to_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i]->fullname_with_scope()));
  1228. if ((from_actor != nullptr) && (to_actor != nullptr)) {
  1229. AddControlArrow(from_actor, to_actor);
  1230. }
  1231. }
  1232. }
  1233. }
  1234. void GraphScheduler::LinkControlArrowForDataPrepareActor(DataPrepareActor *data_prepare_actor,
  1235. const ActorSet *actor_set,
  1236. const ControlNodeParserPtr &parser) {
  1237. MS_EXCEPTION_IF_NULL(data_prepare_actor);
  1238. MS_EXCEPTION_IF_NULL(actor_set);
  1239. MS_EXCEPTION_IF_NULL(parser);
  1240. // Data prepare actor --> data source actor.
  1241. for (auto &data_source_actor : actor_set->data_source_actors_) {
  1242. MS_EXCEPTION_IF_NULL(data_source_actor);
  1243. AddControlArrow(data_prepare_actor, data_source_actor.get());
  1244. }
  1245. // In control flow, control arrow of no input kernel actor needs to be connected to the corresponding entrance actor.
  1246. if (!parser->IsInited()) {
  1247. // Data prepare actor --> no input kernel actor.
  1248. for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  1249. MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
  1250. AddControlArrow(data_prepare_actor, no_input_kernel_actor.get());
  1251. }
  1252. }
  1253. // Data prepare actor --> loop count actor.
  1254. if ((actor_set->data_source_actors_.size() + actor_set->no_input_kernel_actors_.size() == 0) &&
  1255. (actor_set->loop_count_actor_ != nullptr)) {
  1256. AddControlArrow(data_prepare_actor, actor_set->loop_count_actor_.get());
  1257. }
  1258. }
  1259. void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  1260. const ControlNodeParserPtr &parser) {
  1261. MS_EXCEPTION_IF_NULL(actor_set);
  1262. MS_EXCEPTION_IF_NULL(parser);
  1263. // There is no loop count actor in step mode.
  1264. if (loop_count_actor == nullptr) {
  1265. return;
  1266. }
  1267. // Collect the actors which have no output.
  1268. std::vector<MemoryAwareActor *> no_output_actors;
  1269. for (auto &super_actor : actor_set->super_kernel_actors_) {
  1270. if ((super_actor->output_data_arrows_.size() == 0) && (super_actor->output_control_arrows_.size() == 0)) {
  1271. (void)no_output_actors.emplace_back(super_actor.get());
  1272. }
  1273. }
  1274. for (auto &kernel_actor : actor_set->kernel_actors_) {
  1275. // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
  1276. if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0)) {
  1277. (void)no_output_actors.emplace_back(kernel_actor.get());
  1278. }
  1279. }
  1280. for (auto &data_actor : actor_set->data_source_actors_) {
  1281. if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
  1282. (void)no_output_actors.emplace_back(data_actor.get());
  1283. }
  1284. }
  1285. for (auto &copy_actor : copy_actors_) {
  1286. if ((copy_actor->output_data_arrows_.size() == 0) && (copy_actor->output_control_arrows_.size() == 0)) {
  1287. (void)no_output_actors.emplace_back(copy_actor.get());
  1288. }
  1289. }
  1290. // No output actor --> loop count actor.
  1291. // In control flow scenario, no output actor needs to be connected to the corresponding exit actor, not loop count.
  1292. if (!parser->IsInited()) {
  1293. for (auto &no_output_actor : no_output_actors) {
  1294. AddControlArrow(no_output_actor, loop_count_actor);
  1295. }
  1296. }
  1297. // Loop count actor --> output actor.
  1298. AddControlArrow(loop_count_actor, actor_set->output_actor_.get());
  1299. // Loop count actor --> data prepare actor.
  1300. MS_EXCEPTION_IF_NULL(actor_set->data_prepare_actor_);
  1301. loop_count_actor->data_prepare_aid_ = actor_set->data_prepare_actor_->GetAID();
  1302. }
  1303. void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
  1304. const GraphCompilerInfo &graph_compiler_info) {
  1305. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep ||
  1306. (graph_compiler_info.control_node_parser_ != nullptr && graph_compiler_info.control_node_parser_->IsInited())) {
  1307. // In control flow, the exit actor of the root graph sends output data to the output actor.
  1308. return;
  1309. }
  1310. MS_EXCEPTION_IF_NULL(to_actor);
  1311. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  1312. const auto &graph = graph_compiler_info.graphs_[i];
  1313. MS_EXCEPTION_IF_NULL(graph);
  1314. auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
  1315. std::set<std::vector<size_t>> unique_output_positions;
  1316. std::set<KernelWithIndex> unique_outputs;
  1317. for (const auto &output : outputs) {
  1318. if (IsInternalParameter(output.first, graph)) {
  1319. MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
  1320. continue;
  1321. }
  1322. (void)unique_outputs.insert(output);
  1323. }
  1324. for (const auto &output_with_index : unique_outputs) {
  1325. MS_EXCEPTION_IF_NULL(output_with_index.first);
  1326. auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph);
  1327. const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
  1328. if (iter == graph_compiler_info.origin_outputs_order_.end()) {
  1329. continue;
  1330. }
  1331. // Skip duplicate position.
  1332. if (unique_output_positions.count(iter->second) > 0) {
  1333. continue;
  1334. }
  1335. (void)unique_output_positions.insert(iter->second);
  1336. for (auto &output_position : iter->second) {
  1337. if (output_position >= to_actor->device_contexts_.size()) {
  1338. MS_LOG(EXCEPTION) << "The output position is out of range.";
  1339. }
  1340. to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[i];
  1341. // The graph output is from device tensor store.
  1342. if (IsPersistentDeviceTensor(output_with_index.first)) {
  1343. (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
  1344. auto device_tensor = AnfAlgo::GetMutableOutputAddr(output_with_index.first, output_with_index.second, false);
  1345. MS_EXCEPTION_IF_NULL(device_tensor);
  1346. // The output actor need use the relevant information of node to create output tensor.
  1347. device_tensor->SetNodeIndex(output_with_index.first, output_with_index.second);
  1348. continue;
  1349. }
  1350. // The graph output is from kernel actor or data source actor.
  1351. auto kernel_type = FetchKernelTransformType(
  1352. output_with_index.first, graph, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_);
  1353. auto from_actor = FetchActor(kernel_type, graph_compiler_info.name_, output_with_index.first, graph);
  1354. if (from_actor == nullptr) {
  1355. continue;
  1356. }
  1357. auto real_from_kernel = output_with_index.first;
  1358. // Update the real node in the host data source actor.
  1359. if (kernel_type == KernelTransformType::kHostDataSourceActor) {
  1360. auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
  1361. MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
  1362. auto position = host_queue_ds_actor->FetchNodePosition(output_with_index.first);
  1363. real_from_kernel = host_queue_ds_actor->FetchNode(position);
  1364. UpdateRefCount(output_with_index.first, output_with_index.second, true);
  1365. }
  1366. AddResultArrow(from_actor, to_actor, real_from_kernel, output_with_index.second, output_position);
  1367. }
  1368. }
  1369. }
  1370. }
  1371. void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<AbstractActor *> &auto_monad_actors) {
  1372. const size_t kNeedUpdateDeviceTensorStoreNum = 2;
  1373. for (auto &auto_monad_actor : auto_monad_actors) {
  1374. MS_EXCEPTION_IF_NULL(auto_monad_actor);
  1375. for (auto &device_tensor_store_key : auto_monad_actor->device_tensor_store_keys_) {
  1376. auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
  1377. if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
  1378. continue;
  1379. }
  1380. // Create the copy actor.
  1381. std::string name = "copy_from:" + auto_monad_actor->GetAID().Name() +
  1382. "_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
  1383. if (FetchActor(name) != nullptr) {
  1384. continue;
  1385. }
  1386. auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
  1387. MS_EXCEPTION_IF_NULL(copy_actor);
  1388. (void)copy_actors_.emplace_back(copy_actor);
  1389. InsertActor(copy_actor.get());
  1390. // Set the member of the copy actor.
  1391. (void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
  1392. auto input_device_context = auto_monad_actor->device_contexts_[0];
  1393. (void)copy_actor->device_contexts_.emplace_back(input_device_context);
  1394. auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
  1395. ? device_tensors[1]
  1396. : device_tensors[0];
  1397. MS_EXCEPTION_IF_NULL(another_device_tensor);
  1398. auto another_device_type = another_device_tensor->DeviceType();
  1399. const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
  1400. {device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
  1401. MS_EXCEPTION_IF_NULL(another_device_context);
  1402. (void)copy_actor->device_contexts_.emplace_back(another_device_context);
  1403. MS_LOG(INFO) << "The auto monad actor: " << auto_monad_actor->GetAID().Name()
  1404. << "has control arrows number:" << auto_monad_actor->output_control_arrows_.size();
  1405. // Link from copy actor to auto monad actor users.
  1406. for (auto &output_contorl : auto_monad_actor->output_control_arrows_) {
  1407. (void)copy_actor->output_control_arrows_.emplace_back(output_contorl);
  1408. }
  1409. // Move the control arrows from auto monad actor to auto monad actor users.
  1410. auto_monad_actor->output_control_arrows_.clear();
  1411. // Link from auto monad actor to copy actor.
  1412. AddControlArrow(auto_monad_actor, copy_actor.get());
  1413. }
  1414. }
  1415. }
  1416. void GraphScheduler::AddDeviceTensorStore(const AnfNode *anf_node, const DeviceTensorPtr &device_tensor) {
  1417. MS_EXCEPTION_IF_NULL(device_tensor);
  1418. DeviceTensorStore::GetInstance().Insert(const_cast<AnfNode *>(anf_node), device_tensor);
  1419. UpdateRefCount(device_tensor.get(), true);
  1420. }
  1421. void GraphScheduler::AddDataArrow(AbstractActor *const from_actor, AbstractActor *const to_actor,
  1422. const AnfNodePtr &from_kernel, size_t from_output_index, size_t to_input_index) {
  1423. MS_EXCEPTION_IF_NULL(from_actor);
  1424. MS_EXCEPTION_IF_NULL(to_actor);
  1425. auto data_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), to_input_index);
  1426. (void)from_actor->output_data_arrows_.emplace_back(data_arrow);
  1427. (void)from_actor->output_data_nodes_.emplace_back(from_kernel);
  1428. to_actor->input_datas_num_++;
  1429. (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
  1430. if (from_kernel == nullptr) {
  1431. return;
  1432. }
  1433. // Update the reference count of from_kernel.
  1434. // The device address of super kernel actor can't be changed, so set the max reference count.
  1435. if ((from_actor->type_ == KernelTransformType::kSuperKernelActor) ||
  1436. (to_actor->type_ == KernelTransformType::kSuperKernelActor)) {
  1437. UpdateRefCount(from_kernel, from_output_index, true);
  1438. } else {
  1439. UpdateRefCount(from_kernel, from_output_index, false);
  1440. }
  1441. }
  1442. void GraphScheduler::AddResultArrow(AbstractActor *const from_actor, OutputActor *const to_actor,
  1443. const AnfNodePtr &from_kernel, size_t from_output_index, size_t output_position) {
  1444. MS_EXCEPTION_IF_NULL(from_actor);
  1445. MS_EXCEPTION_IF_NULL(to_actor);
  1446. MS_EXCEPTION_IF_NULL(from_kernel);
  1447. auto result_arrow = std::make_shared<DataArrow>(from_output_index, to_actor->GetAID(), output_position);
  1448. (void)from_actor->output_data_arrows_.insert(from_actor->output_data_arrows_.begin(), result_arrow);
  1449. (void)from_actor->output_data_nodes_.insert(from_actor->output_data_nodes_.begin(), from_kernel);
  1450. to_actor->input_datas_num_++;
  1451. (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
  1452. auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
  1453. MS_EXCEPTION_IF_NULL(device_tensor);
  1454. // The output actor need use the relevant information of node to create output tensor.
  1455. device_tensor->SetNodeIndex(from_kernel, from_output_index);
  1456. // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
  1457. UpdateRefCount(device_tensor.get(), true);
  1458. }
  1459. void GraphScheduler::AddControlArrow(AbstractActor *const from_actor, AbstractActor *const to_actor) {
  1460. MS_EXCEPTION_IF_NULL(from_actor);
  1461. MS_EXCEPTION_IF_NULL(to_actor);
  1462. (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
  1463. to_actor->input_controls_num_++;
  1464. (void)to_actor->input_control_arrow_aids_.emplace_back(from_actor->GetAID());
  1465. }
  1466. void GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
  1467. MS_EXCEPTION_IF_NULL(actor_set);
  1468. auto actors = CollectActors(actor_set);
  1469. for (auto &actor : actors) {
  1470. MS_EXCEPTION_IF_NULL(actor);
  1471. if (actor->type_ >= KernelTransformType::kSwitchActor) {
  1472. continue;
  1473. }
  1474. if ((actor->input_datas_num_ != actor->input_data_arrow_aids_.size()) ||
  1475. (actor->input_controls_num_ != actor->input_control_arrow_aids_.size())) {
  1476. MS_LOG(EXCEPTION) << "The input num of " << actor->GetAID().Name()
  1477. << " is wrong, expect data num: " << actor->input_datas_num_
  1478. << ", actual data num: " << actor->input_data_arrow_aids_.size()
  1479. << ", expect control num: " << actor->input_controls_num_
  1480. << ", actual control num: " << actor->input_control_arrow_aids_.size();
  1481. }
  1482. if ((actor->type_ != KernelTransformType::kOutputActor) && (actor->output_data_arrows_.size() == 0) &&
  1483. (actor->output_control_arrows_.size() == 0)) {
  1484. MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no user.";
  1485. }
  1486. if ((actor->type_ != KernelTransformType::kDataPrepareActor) && (actor->input_datas_num_ == 0) &&
  1487. (actor->input_controls_num_ == 0)) {
  1488. MS_LOG(EXCEPTION) << actor->GetAID().Name() << " has no source.";
  1489. }
  1490. // Check the input of kernel actors and copy actors.
  1491. if ((actor->type_ == KernelTransformType::kKernelActor) || (actor->type_ == KernelTransformType::kCopyActor)) {
  1492. size_t expect_toal_input_num = 1;
  1493. if (actor->type_ == KernelTransformType::kKernelActor) {
  1494. auto kernel_actor = dynamic_cast<KernelActor *>(actor.get());
  1495. MS_EXCEPTION_IF_NULL(kernel_actor);
  1496. expect_toal_input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
  1497. }
  1498. auto input_data_num = actor->input_datas_num_;
  1499. auto device_tensor_store_num = actor->device_tensor_store_keys_.size();
  1500. if (input_data_num + device_tensor_store_num != expect_toal_input_num) {
  1501. MS_LOG(EXCEPTION) << "The input building of " << actor->GetAID().Name()
  1502. << " is wrong, input data num: " << input_data_num
  1503. << ", device tensor store num: " << device_tensor_store_num
  1504. << ", total input num: " << expect_toal_input_num;
  1505. }
  1506. }
  1507. }
  1508. // Check the output actor.
  1509. auto output_actor = actor_set->output_actor_;
  1510. MS_EXCEPTION_IF_NULL(output_actor);
  1511. if (output_actor->input_datas_num_ + output_actor->device_tensor_store_keys_.size() != output_actor->outputs_num_) {
  1512. MS_LOG(EXCEPTION) << "The outputs num of output actor is wrong, the total outputs num: "
  1513. << output_actor->outputs_num_ << ", the input data arrows num: " << output_actor->input_datas_num_
  1514. << ", the device tensor store num: " << output_actor->device_tensor_store_keys_.size();
  1515. }
  1516. }
  1517. void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
  1518. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  1519. const auto &graph = graph_compiler_info.graphs_[i];
  1520. const auto &device_context = graph_compiler_info.device_contexts_[i];
  1521. MS_EXCEPTION_IF_NULL(graph);
  1522. MS_EXCEPTION_IF_NULL(device_context);
  1523. for (auto &value_node : graph->graph_value_nodes()) {
  1524. MS_EXCEPTION_IF_NULL(value_node);
  1525. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  1526. MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
  1527. continue;
  1528. }
  1529. auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
  1530. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  1531. device_tensor->SetNodeIndex(value_node, 0);
  1532. AddDeviceTensorStore(front_node.get(), device_tensor);
  1533. }
  1534. for (auto &input_node : graph->input_nodes()) {
  1535. MS_EXCEPTION_IF_NULL(input_node);
  1536. AnfNodePtr sub_front_node = nullptr;
  1537. if (IsInternalParameter(input_node, graph)) {
  1538. auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node);
  1539. sub_front_node = front_output_with_index.first;
  1540. } else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) {
  1541. sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
  1542. }
  1543. if (sub_front_node == nullptr) {
  1544. continue;
  1545. }
  1546. // The sub front nodes share the device tensor store with the root front node.
  1547. MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
  1548. auto front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
  1549. MS_EXCEPTION_IF_NULL(front_node);
  1550. MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString()
  1551. << ", root front node:" << front_node->DebugString();
  1552. auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
  1553. MS_EXCEPTION_IF_NULL(device_tensor);
  1554. if (IsPersistentDeviceTensor(input_node)) {
  1555. device_tensor->SetNodeIndex(input_node, 0);
  1556. AddDeviceTensorStore(front_node.get(), device_tensor);
  1557. }
  1558. // Share the weight in the host and device, then input_node is internal parameter and front_node is weight.
  1559. if (!IsPersistentDeviceTensor(front_node)) {
  1560. continue;
  1561. }
  1562. // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
  1563. if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
  1564. MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
  1565. << ", type:" << device_context->GetDeviceAddressType();
  1566. auto other_type_device_tensor = device_context->CreateDeviceAddress(
  1567. nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
  1568. other_type_device_tensor->SetNodeIndex(input_node, 0);
  1569. AddDeviceTensorStore(front_node.get(), other_type_device_tensor);
  1570. }
  1571. }
  1572. }
  1573. }
  1574. void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const {
  1575. MS_EXCEPTION_IF_NULL(actor_set);
  1576. const auto &context_ptr = MsContext::GetInstance();
  1577. MS_EXCEPTION_IF_NULL(context_ptr);
  1578. auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  1579. if (!save_graphs) {
  1580. return;
  1581. }
  1582. std::string filename = GetSaveGraphsPathName("actor_set_" + actor_set->name_ + ".ir");
  1583. std::ofstream ofs(filename);
  1584. if (!ofs.is_open()) {
  1585. MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
  1586. return;
  1587. }
  1588. DumpDeviceTensorStore(graph_compiler_info, ofs);
  1589. DumpDataPrepareActor(actor_set->data_prepare_actor_, ofs);
  1590. DumpDSActors(actor_set->data_source_actors_, ofs);
  1591. DumpKernelActors(actor_set->kernel_actors_, ofs);
  1592. DumpSuperKernelActors(actor_set->super_kernel_actors_, ofs);
  1593. // The on input kernel actors are taken over by control actor in the control flow scene.
  1594. if ((graph_compiler_info.control_node_parser_ == nullptr) ||
  1595. (!graph_compiler_info.control_node_parser_->IsInited())) {
  1596. DumpNoInputKernelActors(actor_set->no_input_kernel_actors_, ofs);
  1597. }
  1598. DumpCopyActors(actor_set->copy_actors_, ofs);
  1599. DumpLoopCountActor(actor_set->loop_count_actor_, ofs);
  1600. DumpOutputActor(actor_set->output_actor_, ofs);
  1601. DumpControlActors(actor_set->control_actors_, ofs);
  1602. }
  1603. void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
  1604. ofs << "[Device tensor stores]\n";
  1605. for (const auto &graph : graph_compiler_info.graphs_) {
  1606. MS_EXCEPTION_IF_NULL(graph);
  1607. ofs << "\tgraph_id:" << graph->graph_id() << "\tis_executing_sink:" << graph->is_executing_sink()
  1608. << "\tis_loop_count_sink:" << graph->is_loop_count_sink()
  1609. << "\texecution_strategy:" << graph_compiler_info.strategy_ << "\n";
  1610. for (auto &value_node : graph->graph_value_nodes()) {
  1611. MS_EXCEPTION_IF_NULL(value_node);
  1612. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  1613. continue;
  1614. }
  1615. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  1616. MS_EXCEPTION_IF_NULL(front_node);
  1617. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  1618. ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
  1619. << "\n";
  1620. for (const auto &device_tensor : device_tensors) {
  1621. MS_EXCEPTION_IF_NULL(device_tensor);
  1622. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  1623. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  1624. << "\tdevice_type:" << device_tensor->DeviceType()
  1625. << "\tis_ptr_persisted:" << device_tensor->is_ptr_persisted() << "\n ";
  1626. }
  1627. }
  1628. for (auto &input_node : graph->input_nodes()) {
  1629. MS_EXCEPTION_IF_NULL(input_node);
  1630. if (!IsPersistentDeviceTensor(input_node)) {
  1631. continue;
  1632. }
  1633. const auto &sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
  1634. // The sub front nodes share the device tensor store with the root front node.
  1635. auto front_node = sub_front_node;
  1636. if (graph_compiler_info.control_node_parser_ != nullptr) {
  1637. front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
  1638. }
  1639. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  1640. MS_EXCEPTION_IF_NULL(front_node);
  1641. ofs << "\t\tdevice tensor key:" << front_node->fullname_with_scope() << "\tvalue size:" << device_tensors.size()
  1642. << "\n";
  1643. for (const auto &device_tensor : device_tensors) {
  1644. MS_EXCEPTION_IF_NULL(device_tensor);
  1645. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  1646. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  1647. << "\tdevice_type:" << device_tensor->DeviceType()
  1648. << "\tis_ptr_persisted:" << device_tensor->is_ptr_persisted() << "\n ";
  1649. }
  1650. }
  1651. ofs << "\n";
  1652. }
  1653. }
  1654. } // namespace runtime
  1655. } // namespace mindspore