You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.cc 151 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_scheduler.h"
  17. #include "runtime/framework/actor/memory_manager_actor.h"
  18. #include "runtime/framework/actor/debug_actor.h"
  19. #include "runtime/framework/actor/recorder_actor.h"
  20. #include "runtime/hardware/device_context_manager.h"
  21. #include "mindrt/src/actor/actormgr.h"
  22. #include "mindrt/include/async/async.h"
  23. #include "backend/session/anf_runtime_algorithm.h"
  24. #include "backend/optimizer/common/helper.h"
  25. #include "utils/config_manager.h"
  26. #include "utils/log_adapter.h"
  27. #include "utils/convert_utils.h"
  28. #include "utils/ms_context.h"
  29. #if !defined(_WIN32) && !defined(_WIN64)
  30. #include "utils/signal_util.h"
  31. #endif
  32. #include "common/trans.h"
  33. #include "debug/data_dump/dump_json_parser.h"
  34. #ifdef ENABLE_DUMP_IR
  35. #include "debug/rdr/recorder_manager.h"
  36. #endif
  37. #ifdef ENABLE_DEBUGGER
  38. #include "debug/debugger/debugger.h"
  39. #endif
  40. #include "profiler/device/profiling.h"
  41. #include "debug/common.h"
  42. namespace mindspore {
  43. namespace runtime {
  44. namespace {
  45. bool IsNeedInsertCopyActor(const DeviceContext *from_device_context, const DeviceContext *to_device_context) {
  46. MS_EXCEPTION_IF_NULL(from_device_context);
  47. MS_EXCEPTION_IF_NULL(to_device_context);
  48. if (from_device_context->GetDeviceAddressType() == to_device_context->GetDeviceAddressType()) {
  49. return false;
  50. } else {
  51. return true;
  52. }
  53. }
  54. void UpdateRefCount(DeviceTensor *const device_tensor, bool is_max_ref_count = false) {
  55. MS_EXCEPTION_IF_NULL(device_tensor);
  56. if (is_max_ref_count) {
  57. device_tensor->set_original_ref_count(SIZE_MAX);
  58. } else {
  59. device_tensor->IncreaseOriginalRefCount();
  60. }
  61. device_tensor->ResetRefCount();
  62. }
  63. // Update the reference count of device tensor by the output index of node.
  64. void UpdateRefCount(const AnfNodePtr &node, size_t output_idx, bool is_max_ref_count = false) {
  65. MS_EXCEPTION_IF_NULL(node);
  66. auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx, false);
  67. UpdateRefCount(device_tensor.get(), is_max_ref_count);
  68. }
  69. AnfNodePtr FetchFrontNodeByBackendNode(const AnfNodePtr &backend_node, const KernelGraphPtr &graph) {
  70. MS_EXCEPTION_IF_NULL(backend_node);
  71. MS_EXCEPTION_IF_NULL(graph);
  72. // Internal parameter ---> front node.
  73. auto front_node_with_index = graph->GetFrontNodeByInternalParameter(backend_node);
  74. if (front_node_with_index.first != nullptr) {
  75. return front_node_with_index.first;
  76. }
  77. auto front_node = graph->GetFrontAnfByBackendAnf(backend_node);
  78. // PyNative forward graph does not has front node, using backend node instead.
  79. if (front_node == nullptr) {
  80. front_node = backend_node;
  81. }
  82. return front_node;
  83. }
  84. KernelWithIndex FetchFrontNodeWithIndexByGraphOutput(const KernelWithIndex &output_with_index,
  85. const KernelGraphPtr &graph) {
  86. MS_EXCEPTION_IF_NULL(graph);
  87. auto front_node_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
  88. // PyNative forward graph does not has front node, using backend node instead.
  89. if (front_node_with_index.first == nullptr) {
  90. front_node_with_index = output_with_index;
  91. }
  92. return front_node_with_index;
  93. }
// The branch processing of PrepareDataForValueNode that value type is tensor.
// Converts the node's value into host tensors, then for each output: binds the
// node's device tensor to the host tensor, allocates device memory and syncs
// host data to device.
// NOTE(review): the early `return` statements below (null tensor, or device ptr
// already set) abandon ALL remaining outputs of this value node, not just the
// current index — confirm this is intended rather than `continue`.
void PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
                                   const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(node_value);
  MS_EXCEPTION_IF_NULL(device_context);
  std::vector<TensorPtr> tensors;
  TensorValueToTensor(node_value, &tensors);
  for (size_t i = 0; i < tensors.size(); i++) {
    const auto &tensor = tensors[i];
    if (tensor == nullptr) {
      MS_LOG(WARNING) << "Tensor is null";
      return;
    }
    const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i, false);
    MS_EXCEPTION_IF_NULL(device_tensor);
    // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
    if (device_tensor->GetPtr() != nullptr) {
      return;
    }
    MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope() << ", output index: " << i;
    // Bind the device tensor to the host tensor so later lookups see the device copy.
    tensor->set_device_address(device_tensor);
    // Allocate device memory.
    if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
      MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
                        << ") memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
                        << ", alloc size: " << device_tensor->GetSize();
    }
    // Copy data from host tensor to device.
    if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), LongToSize(tensor->data().nbytes()),
                                         tensor->data_type(), tensor->data_c(), tensor->device_info().host_format_)) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
    }
  }
}
// Prepare the device data for persistent device tensor of value node.
// Dispatches on the value type: tensor/tuple values are delegated to
// PrepareDataForValueNodeTensor; string values are copied to device as a
// 1 x len uint8 buffer. Other value kinds are silently ignored here.
void PrepareDataForValueNode(const ValueNodePtr &node, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  auto &node_value = node->value();
  MS_EXCEPTION_IF_NULL(node_value);
  if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
    // The branch processing that value type is tensor.
    PrepareDataForValueNodeTensor(node, node_value, device_context);
  } else if (node_value->isa<StringImm>()) {
    const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0, false);
    MS_EXCEPTION_IF_NULL(device_tensor);
    // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
    if (device_tensor->GetPtr() != nullptr) {
      return;
    }
    MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope();
    // Allocate device memory.
    if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
      MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
                        << ") memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
                        << ", alloc size: " << device_tensor->GetSize();
    }
    // Copy data from value to device.
    auto value = GetValue<std::string>(node_value);
    size_t tensor_size = value.size();
    // The string is laid out on device as a 1 x tensor_size uint8 tensor.
    ShapeVector shape = {1, SizeToLong(tensor_size)};
    if (!device_tensor->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
    }
  }
}
// Prepare the device data for persistent device tensor of weight node from host tensor.
// Ensures the host tensor owns a device address of the right device type,
// registers it in the device tensor store keyed by the front node, allocates
// and syncs device memory when missing, and mirrors the data to a second
// device address if the store holds one for another device type.
void PrepareDataForWeightNode(const AnfNodePtr &backend_node, const AnfNodePtr &front_node, const TensorPtr &tensor,
                              const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(backend_node);
  MS_EXCEPTION_IF_NULL(front_node);
  MS_EXCEPTION_IF_NULL(device_context);
  // Weight without a host tensor: nothing to prepare (deliberate best-effort skip).
  if (tensor == nullptr) {
    return;
  }
  auto device_tensor = AnfAlgo::GetMutableOutputAddr(backend_node, 0, false);
  auto host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
  // Use the device address of host tensor to set device tensor.
  if (host_tensor_address != device_tensor) {
    if (host_tensor_address == nullptr) {
      MS_EXCEPTION_IF_NULL(device_tensor);
      // The host tensor has no device address yet: create one matching the
      // node's size/format/type and pin it (max ref count) so it persists.
      host_tensor_address = device_context->CreateDeviceAddress(nullptr, device_tensor->GetSize(),
                                                               device_tensor->format(), device_tensor->type_id());
      tensor->set_device_address(host_tensor_address);
      UpdateRefCount(host_tensor_address.get(), true);
    }
    MS_EXCEPTION_IF_NULL(host_tensor_address);
    DeviceTensorStore::GetInstance().Insert(front_node.get(), host_tensor_address);
    if (host_tensor_address->DeviceType() == device_tensor->DeviceType()) {
      AnfAlgo::SetOutputAddr(host_tensor_address, 0, backend_node.get());
    } else {
      MS_LOG(INFO) << "The device type is not equal, host tensor type:" << host_tensor_address->DeviceType()
                   << ", device tensor type:" << device_tensor->DeviceType();
    }
  }
  // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  // NOTE(review): if both host_tensor_address and device_tensor were null, the
  // branch above is skipped and this dereferences a null pointer — confirm that
  // callers guarantee at least one of them is non-null.
  if (host_tensor_address->GetPtr() == nullptr) {
    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device type:" << host_tensor_address->DeviceType();
    // Allocate device memory and copy data from host tensor to device.
    if (!device_context->AllocateMemory(host_tensor_address.get(), host_tensor_address->GetSize())) {
      MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
                        << ") memory isn't enough and alloc failed, node name: " << backend_node->fullname_with_scope();
    }
    if (!host_tensor_address->SyncHostToDevice(trans::GetRuntimePaddingShape(backend_node, 0),
                                               LongToSize(tensor->data().nbytes()), tensor->data_type(),
                                               tensor->data_c(), tensor->device_info().host_format_)) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << backend_node->fullname_with_scope();
    }
  }
  // Allocate another device memory and copy data from host tensor to another device(if exist).
  const auto &device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  if (device_tensors.size() > 1) {
    // Pick whichever stored tensor is not the host one; it belongs to the other device.
    auto another_device_tensor = (device_tensors[0] == host_tensor_address) ? device_tensors[1] : device_tensors[0];
    MS_EXCEPTION_IF_NULL(another_device_tensor);
    auto another_device_type = another_device_tensor->DeviceType();
    const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
      {device::kDeviceTypeToName.at(another_device_type), device_context->device_context_key().device_id_});
    MS_EXCEPTION_IF_NULL(another_device_context);
    if (another_device_tensor->GetPtr() == nullptr) {
      if (!another_device_context->AllocateMemory(another_device_tensor.get(), another_device_tensor->GetSize())) {
        MS_LOG(EXCEPTION) << "Device(id:" << another_device_context->device_context_key().device_id_
                          << ") memory isn't enough and alloc failed, node name: "
                          << backend_node->fullname_with_scope();
      }
    }
    MS_LOG(INFO) << "Prepare device data for weight node:" << backend_node->fullname_with_scope()
                 << ", device type:" << another_device_type;
    // Device-to-device (via host) copy from the freshly prepared host address.
    if (!Copy(another_device_tensor.get(), host_tensor_address.get())) {
      MS_LOG(EXCEPTION) << "Sync data error.";
    }
  }
}
  228. // In control flow, all weight nodes associated with the host weight parameter need to use the same device tensor.
  229. void PrepareDataForControlWeightNode(
  230. const AnfNodePtr &node, const AnfNodePtr &front_node, const TensorPtr &tensor, const DeviceContext *device_context,
  231. const std::unordered_map<AnfNodePtr, std::vector<AnfNodePtr>> &host_parameter_to_weights = {}) {
  232. MS_EXCEPTION_IF_NULL(node);
  233. MS_EXCEPTION_IF_NULL(front_node);
  234. MS_EXCEPTION_IF_NULL(tensor);
  235. MS_EXCEPTION_IF_NULL(device_context);
  236. auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  237. bool need_update_device_tensor_store = (device_tensors.size() == 0) ? true : false;
  238. for (auto &device_tensor : device_tensors) {
  239. if (device_tensor->GetPtr() == nullptr) {
  240. need_update_device_tensor_store = true;
  241. break;
  242. }
  243. }
  244. if (need_update_device_tensor_store) {
  245. PrepareDataForWeightNode(node, front_node, tensor, device_context);
  246. }
  247. const auto iter = host_parameter_to_weights.find(front_node);
  248. if (iter == host_parameter_to_weights.end()) {
  249. return;
  250. }
  251. // Fetch all the device tensors of host weight node and insert as the weight of other nodes.
  252. const auto &sub_front_nodes = host_parameter_to_weights.at(front_node);
  253. device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  254. for (const auto &sub_front_node : sub_front_nodes) {
  255. for (const auto &device_tensor : device_tensors) {
  256. if (sub_front_node == nullptr) {
  257. MS_LOG(EXCEPTION) << "Front node is empty!";
  258. }
  259. DeviceTensorStore::GetInstance().Insert(sub_front_node.get(), device_tensor);
  260. }
  261. }
  262. }
  263. void PrepareDataForHostDataSourceActor(const std::unordered_map<AnfNodePtr, size_t> &data_node_position_map,
  264. const AnfNodePtr &node, const TensorPtr &tensor,
  265. std::vector<TensorPtr> *const host_tensors) {
  266. MS_EXCEPTION_IF_NULL(tensor);
  267. // Fill the host tensors for non weighted parameters.
  268. const auto &iter = data_node_position_map.find(node);
  269. if (iter == data_node_position_map.end()) {
  270. return;
  271. }
  272. (*host_tensors)[iter->second] = tensor;
  273. auto tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
  274. auto device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
  275. MS_EXCEPTION_IF_NULL(device_address);
  276. if ((tensor_address != nullptr) && (tensor_address->DeviceType() == device_address->DeviceType())) {
  277. AnfAlgo::SetOutputAddr(tensor_address, 0, node.get());
  278. }
  279. }
  280. void PrepareDataForInputData(const HostQueueDataSourceActor *host_data_source_actor, const AnfNodePtr &node,
  281. const TensorPtr &tensor, const DeviceContext *device_context,
  282. std::vector<TensorPtr> *const host_tensors) {
  283. MS_EXCEPTION_IF_NULL(tensor);
  284. // Fill the host tensors for non weighted parameters.
  285. if (host_data_source_actor != nullptr) {
  286. (*host_tensors)[host_data_source_actor->FetchNodePosition(node)] = tensor;
  287. }
  288. auto device_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
  289. if (device_address != nullptr) {
  290. AnfAlgo::SetOutputAddr(device_address, 0, node.get());
  291. return;
  292. }
  293. DeviceTensorPtr node_device_address = nullptr;
  294. if (!AnfAlgo::OutputAddrExist(node, 0, false)) {
  295. TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, 0);
  296. if (output_type_id == kTypeUnknown) {
  297. output_type_id = AnfAlgo::GetOutputInferDataType(node, 0);
  298. }
  299. size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(node, 0);
  300. auto new_device_address =
  301. device_context->CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(node, 0), output_type_id);
  302. AnfAlgo::SetOutputAddr(new_device_address, 0, node.get());
  303. node_device_address = new_device_address;
  304. } else {
  305. node_device_address = AnfAlgo::GetMutableOutputAddr(node, 0, false);
  306. }
  307. tensor->set_device_address(node_device_address);
  308. UpdateRefCount(node_device_address.get(), true);
  309. MS_EXCEPTION_IF_NULL(device_context);
  310. if (node_device_address->GetPtr() == nullptr &&
  311. !device_context->AllocateMemory(node_device_address.get(), node_device_address->GetSize())) {
  312. MS_LOG(EXCEPTION) << "Device(id:" << device_context->device_context_key().device_id_
  313. << ") memory isn't enough and alloc failed, node name: " << node->fullname_with_scope();
  314. }
  315. if (!node_device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0),
  316. LongToSize(tensor->data().nbytes()), tensor->data_type(), tensor->data_c(),
  317. tensor->device_info().host_format_)) {
  318. MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
  319. }
  320. }
  321. inline bool IsSingleOpActorSet(const ActorSet *actor_set) {
  322. MS_EXCEPTION_IF_NULL(actor_set);
  323. return actor_set->kernel_actors_.size() == 1;
  324. }
// Run the actor set for one step. Single-op sets run their kernel actor
// synchronously with the given input tensors; otherwise data source and
// no-input kernel actors are triggered asynchronously and the call blocks on
// the result promise. Returns whether the step completed successfully.
bool RunInStepMode(const ActorSet *actor_set, const std::vector<TensorPtr> *input_tensors) {
  OpContext<DeviceTensor> op_context;
  // Step mode does not need sequential number.
  op_context.sequential_num_ = nullptr;
  // Trigger kernel actor running in the step execution strategy.
  if (IsSingleOpActorSet(actor_set)) {
    MS_EXCEPTION_IF_NULL(input_tensors);
    for (auto &kernel_actor : actor_set->kernel_actors_) {
      MS_EXCEPTION_IF_NULL(kernel_actor);
      // Synchronous single-op path: no promise/future round trip needed.
      kernel_actor->RunOpControlWithInputTensor(nullptr, &op_context, input_tensors);
    }
    return true;
  }
  // Multi-actor path: one promise collects the step's final result.
  std::vector<Promise<int>> result(1);
  op_context.results_ = &result;
  // Trigger data source actor running.
  for (auto &data_source_actor : actor_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(data_source_actor);
    Async(data_source_actor->GetAID(), &DataSourceActor::FetchData, &op_context);
  }
  // Trigger no input kernel actor running.
  for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
    MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
    Async(no_input_kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context);
  }
  // Block until the step finishes, then surface any exception raised in actors.
  auto result_future = result[0].GetFuture();
  result_future.Wait();
  MsException::Instance().CheckException();
  return result_future.IsOK();
}
  355. // Convert the actors vector by the actor set.
  356. std::vector<ActorReference> CollectActors(const ActorSet *actor_set) {
  357. MS_EXCEPTION_IF_NULL(actor_set);
  358. std::vector<ActorReference> actors;
  359. for (auto &data_source_actor : actor_set->data_source_actors_) {
  360. MS_EXCEPTION_IF_NULL(data_source_actor);
  361. (void)actors.emplace_back(static_cast<ActorReference>(data_source_actor));
  362. }
  363. for (auto &kernel_actor : actor_set->kernel_actors_) {
  364. MS_EXCEPTION_IF_NULL(kernel_actor);
  365. (void)actors.emplace_back(static_cast<ActorReference>(kernel_actor));
  366. }
  367. for (auto &switch_actor : actor_set->switch_actors_) {
  368. MS_EXCEPTION_IF_NULL(switch_actor);
  369. (void)actors.emplace_back(static_cast<ActorReference>(switch_actor));
  370. }
  371. for (auto &gather_actor : actor_set->gather_actors_) {
  372. MS_EXCEPTION_IF_NULL(gather_actor);
  373. (void)actors.emplace_back(static_cast<ActorReference>(gather_actor));
  374. }
  375. for (auto &copy_actor : actor_set->copy_actors_) {
  376. MS_EXCEPTION_IF_NULL(copy_actor);
  377. (void)actors.emplace_back(static_cast<ActorReference>(copy_actor));
  378. }
  379. if (actor_set->loop_count_actor_ != nullptr) {
  380. (void)actors.emplace_back(static_cast<ActorReference>(actor_set->loop_count_actor_));
  381. }
  382. if (actor_set->output_actor_ != nullptr) {
  383. (void)actors.emplace_back(static_cast<ActorReference>(actor_set->output_actor_));
  384. }
  385. return actors;
  386. }
  387. void ClearNodeInfo(const KernelGraphPtr graph) {
  388. MS_EXCEPTION_IF_NULL(graph);
  389. // Clear input parameter device tensor and device tensor store.
  390. for (const auto &input_node : graph->input_nodes()) {
  391. MS_EXCEPTION_IF_NULL(input_node);
  392. if (!input_node->isa<Parameter>()) {
  393. continue;
  394. }
  395. auto parameter = input_node->cast<ParameterPtr>();
  396. MS_EXCEPTION_IF_NULL(parameter);
  397. parameter->DecreaseUsedGraphCount();
  398. // Only the parameter has no graph used, then clear the device tensor.
  399. if (parameter->used_graph_count() != 0) {
  400. continue;
  401. }
  402. auto front_input_node = FetchFrontNodeByBackendNode(input_node, graph);
  403. DeviceTensorStore::GetInstance().Remove(front_input_node.get());
  404. size_t output_num = AnfAlgo::GetOutputTensorNum(input_node);
  405. for (size_t index = 0; index < output_num; ++index) {
  406. if (AnfAlgo::OutputAddrExist(input_node, index)) {
  407. AnfAlgo::SetOutputAddr(nullptr, index, input_node.get());
  408. }
  409. }
  410. }
  411. // Clear input value node device tensor and device tensor store.
  412. for (const auto &value_node : graph->graph_value_nodes()) {
  413. auto front_value_node = FetchFrontNodeByBackendNode(value_node, graph);
  414. DeviceTensorStore::GetInstance().Remove(front_value_node.get());
  415. if (AnfAlgo::OutputAddrExist(value_node, 0)) {
  416. AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
  417. }
  418. }
  419. // Clear cnode device tensor.
  420. for (const auto &cnode : graph->execution_order()) {
  421. size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  422. for (size_t index = 0; index < output_num; ++index) {
  423. if (AnfAlgo::OutputAddrExist(cnode, index)) {
  424. AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
  425. }
  426. }
  427. }
  428. }
  429. #if !defined(_WIN32) && !defined(_WIN64)
  430. void IntHandler(int, siginfo_t *, void *) {
  431. int this_pid = getpid();
  432. MS_LOG(WARNING) << "Process " << this_pid << " receive KeyboardInterrupt signal.";
  433. (void)kill(this_pid, SIGTERM);
  434. }
  435. #endif
  436. } // namespace
// Release all scheduler-side state (actors, device tensors) tied to this compiler info's name and graphs.
GraphCompilerInfo::~GraphCompilerInfo() { GraphScheduler::GetInstance().Clear(name_, graphs_); }
  438. void GraphScheduler::Clear(const ActorInfo &actor_info, const std::vector<KernelGraphPtr> &graphs) {
  439. // Terminate the actors of actor info.
  440. if (actors_.count(actor_info) > 0) {
  441. auto actorMgr = ActorMgr::GetActorMgrRef();
  442. MS_EXCEPTION_IF_NULL(actorMgr);
  443. auto actor_set = actors_[actor_info];
  444. auto base_actors = CollectActors(actor_set.get());
  445. for (auto &base_actor : base_actors) {
  446. actorMgr->Terminate(base_actor->GetAID());
  447. }
  448. }
  449. // Clear device tensor and device tensor store.
  450. for (auto &graph : graphs) {
  451. ClearNodeInfo(graph);
  452. }
  453. // Clear global maps of actor info.
  454. (void)actors_.erase(actor_info);
  455. (void)actor_to_host_queue_.erase(actor_info);
  456. }
  457. void GraphScheduler::Clear() {
  458. // Terminate all actors.
  459. auto actorMgr = ActorMgr::GetActorMgrRef();
  460. MS_EXCEPTION_IF_NULL(actorMgr);
  461. actorMgr->Finalize();
  462. // Clear the member of DeviceTensorStore.
  463. DeviceTensorStore::GetInstance().Clear();
  464. // Clear global maps.
  465. actors_.clear();
  466. actor_name_to_actor_.clear();
  467. actor_to_host_queue_.clear();
  468. }
// Member-function pointer type used to dispatch data-arrow linking by kernel transform type.
using DataArrowLinkFunc = void (GraphScheduler::*)(AbstractActor *const, KernelActor *const, const KernelWithIndex &,
                                                   const KernelWithIndex &, const KernelGraphPtr &);
// Dispatch table from kernel transform type to its link function; filled once in Initialize().
static std::map<KernelTransformType, DataArrowLinkFunc> kKernelTypeToLinkFunc;
// One-time process-wide initialization: register the data-arrow link dispatch table,
// create the actor thread pool, set the OMP thread environment and spawn the global actors.
void GraphScheduler::Initialize() {
  // Guard against repeated initialization (plain bool check — not synchronized by itself).
  if (init_) {
    return;
  }
  init_ = true;

  // Register the link function for every kernel transform type.
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceDataSourceActor,
                                      &GraphScheduler::LinkDataArrowForDeviceDSActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kHostDataSourceActor,
                                      &GraphScheduler::LinkDataArrowForHostDSActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kKernelActor, &GraphScheduler::LinkDataArrowForKernelActor);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kDeviceTensorStore,
                                      &GraphScheduler::LinkDataArrowForDeviceTensorStore);
  (void)kKernelTypeToLinkFunc.emplace(KernelTransformType::kInternalParameter,
                                      &GraphScheduler::LinkDataArrowForInternalParameter);

  // Create the thread pool of actor runtime and Set the OMP_NUM_THREADS env.
  size_t actor_thread_num = 0;
  size_t OMP_thread_num = 0;
  size_t max_thread_num = 0;
  ComputeThreadNums(&actor_thread_num, &OMP_thread_num, &max_thread_num);
  auto actor_manager = ActorMgr::GetActorMgrRef();
  MS_EXCEPTION_IF_NULL(actor_manager);
  actor_manager->Initialize(true, actor_thread_num, max_thread_num);
  // Overwrite flag is 0: an OMP_NUM_THREADS already present in the environment wins.
  std::string OMP_env = std::to_string(OMP_thread_num);
  (void)common::SetEnv("OMP_NUM_THREADS", OMP_env.c_str(), 0);
  auto OMP_thread_num_used = common::GetEnv("OMP_NUM_THREADS");
  MS_LOG(INFO) << "The actor thread number: " << actor_thread_num
               << ", the computed OMP thread number : " << OMP_thread_num
               << ", the used OMP thread number : " << OMP_thread_num_used;

  // Spawn the global memory-manager / recorder / debug actors.
  BuildAndScheduleGlobalActor();
}
  502. void GraphScheduler::BuildAndScheduleGlobalActor() {
  503. auto actorMgr = ActorMgr::GetActorMgrRef();
  504. MS_EXCEPTION_IF_NULL(actorMgr);
  505. // Create and schedule memory manager actor.
  506. auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
  507. MS_EXCEPTION_IF_NULL(memory_manager_actor);
  508. memory_manager_aid_ = memory_manager_actor->GetAID();
  509. auto base_actor = static_cast<ActorReference>(memory_manager_actor);
  510. // Bind single thread to response to memory alloc and free quickly.
  511. (void)actorMgr->Spawn(base_actor, false);
  512. // Create and schedule recorder actor.
  513. auto recorder_actor = std::make_shared<RecorderActor>();
  514. MS_EXCEPTION_IF_NULL(recorder_actor);
  515. recorder_aid_ = &(recorder_actor->GetAID());
  516. auto base_recorder_actor = static_cast<ActorReference>(recorder_actor);
  517. (void)actorMgr->Spawn(base_recorder_actor, true);
  518. // Create and schedule debug actor.
  519. bool debugger_actor_need = DumpJsonParser::GetInstance().e2e_dump_enabled();
  520. #ifdef ENABLE_DEBUGGER
  521. if (Debugger::GetInstance()->DebuggerBackendEnabled()) {
  522. debugger_actor_need = true;
  523. }
  524. #endif
  525. if (debugger_actor_need) {
  526. auto debug_actor = std::make_shared<DebugActor>();
  527. MS_EXCEPTION_IF_NULL(debug_actor);
  528. debug_aid_ = &(debug_actor->GetAID());
  529. auto base_debug_actor = static_cast<ActorReference>(debug_actor);
  530. (void)actorMgr->Spawn(base_debug_actor, true);
  531. }
  532. }
// Transform the graph compiler info into an actor set: build the actors, cache the graph
// outputs, link the arrows between actors and validate the result.
// NOTE(review): the member caches (graph_output_to_actor_, front_node_to_actor_, copy_actors_)
// are cleared only on the success path; an exception thrown below leaves them populated.
ActorSet *GraphScheduler::Transform(const GraphCompilerInfo &graph_compiler_info) {
  MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor begin.";
  if (graph_compiler_info.graphs_.size() == 0) {
    MS_LOG(EXCEPTION) << "The number of graphs is zero.";
  }
  // Graphs and device contexts are paired by index everywhere below.
  if (graph_compiler_info.graphs_.size() != graph_compiler_info.device_contexts_.size()) {
    MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device contexts.";
  }

  // Persist weights and value nodes into the device tensor store before building actors.
  PersistDeviceTensor(graph_compiler_info);
  auto strategy = graph_compiler_info.strategy_;
  const auto &actor_set = Build(graph_compiler_info);
  // Cache graph-output-node -> producing-actor mapping, used when linking output arrows.
  CacheGraphOutputToActor(graph_compiler_info);
  Link(actor_set.get(), graph_compiler_info);
  // The copy actors are built in the link, so need push into the actor set after link.
  actor_set->copy_actors_ = copy_actors_;
  (void)actors_.emplace(actor_set->name_, actor_set);

  DumpActor(actor_set.get(), graph_compiler_info);
  if (!CheckActorValid(actor_set.get(), strategy)) {
    MS_LOG(EXCEPTION) << "The actor set of " << graph_compiler_info.name_ << " is invalid.";
  }
  MS_LOG(INFO) << "Graph(" << graph_compiler_info.name_ << ") transforms actor end.";

  // Local maps and vectors clear.
  graph_output_to_actor_.clear();
  front_node_to_actor_.clear();
  copy_actors_.clear();
  // The returned raw pointer stays valid because actors_ keeps the shared_ptr alive.
  return actor_set.get();
}
  560. void GraphScheduler::Schedule(const ActorSet *actor_set) {
  561. MS_EXCEPTION_IF_NULL(actor_set);
  562. auto actors = CollectActors(actor_set);
  563. // Schedule actors.
  564. auto actorMgr = ActorMgr::GetActorMgrRef();
  565. MS_EXCEPTION_IF_NULL(actorMgr);
  566. for (auto actor : actors) {
  567. (void)actorMgr->Spawn(actor);
  568. }
  569. }
// Prepare device data before an actor-set run: value nodes and weights go into the device
// tensor store, non-weighted parameters are collected into host_tensors and pushed to the
// host tensor queue that feeds the host data source actor. Also pre-allocates continuous
// memory for communication kernels and prepares control-node data.
void GraphScheduler::PrepareRun(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
                                const std::vector<std::vector<TensorPtr>> &input_tensors) {
  MS_EXCEPTION_IF_NULL(actor_set);
  std::vector<TensorPtr> host_tensors;
  std::string actor_name = actor_set->name_ + "_HostDSActor";
  // Null when this actor set has no host data source actor.
  const auto &host_data_source_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
  if (host_data_source_actor != nullptr) {
    // Pre-size so tensors can be placed by data-node position below.
    host_tensors.resize(host_data_source_actor->data_nodes_.size());
  }

  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // 1.Prepare the data of device tensor store(value nodes of graph).
    for (const auto &value_node : graph->graph_value_nodes()) {
      if (AnfAlgo::OutputAddrExist(value_node, 0)) {
        PrepareDataForValueNode(value_node, device_context);
      }
    }

    // 1.Prepare the data of device tensor store(weights of graph), and fill host tensors for non weighted parameters.
    const auto &input_nodes = graph->input_nodes();
    // NOTE(review): assumes input_tensors[i] has one tensor per input node — confirm with callers.
    const auto &tensors = input_tensors[i];
    for (size_t j = 0; j < input_nodes.size(); ++j) {
      const auto &input_node = input_nodes[j];
      const auto &input_tensor = tensors[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (IsPersistentDeviceTensor(input_node)) {
        // Prepare the device data for weights.
        const auto front_node = FetchFrontNodeByBackendNode(input_node, graph);
        PrepareDataForWeightNode(input_node, front_node, input_tensor, device_context);
      } else if (IsHostQueueDSActor(input_node, graph, graph_compiler_info.origin_parameters_order_,
                                    graph_compiler_info.strategy_)) {
        MS_EXCEPTION_IF_NULL(host_data_source_actor);
        PrepareDataForHostDataSourceActor(host_data_source_actor->data_node_position_map_, input_node, input_tensor,
                                          &host_tensors);
      }
    }
  }

  // 2.Prepare the continuous memory for communication kernel.
  if (actor_set->loop_count_actor_ != nullptr) {
    auto alloc_list_list = actor_set->loop_count_actor_->continuous_memory_alloc_list_list_;
    auto size_list_list = actor_set->loop_count_actor_->size_list_list_;
    auto total_size_list = actor_set->loop_count_actor_->total_size_list_;
    auto device_contexts = actor_set->loop_count_actor_->device_contexts_;
    // The four lists are parallel arrays and must stay the same length.
    if ((alloc_list_list.size() != size_list_list.size()) || (size_list_list.size() != total_size_list.size()) ||
        (total_size_list.size() != device_contexts.size())) {
      MS_LOG(EXCEPTION)
        << "The size of alloc_list_list, size_list_list, total_size_list and device_contexts are not equal.";
    }
    for (size_t i = 0; i < alloc_list_list.size(); ++i) {
      auto &alloc_list = alloc_list_list[i];
      auto &size_list = size_list_list[i];
      auto &total_size = total_size_list[i];
      auto &device_context = device_contexts[i];
      if (!device_context->AllocateContinuousMemory(alloc_list, total_size, size_list)) {
        MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, alloc size: " << total_size;
      }
    }
  }

  // 3.Prepare the data which belongs to control node.
  // NOTE(review): input_tensors.back() is taken as the control-node tensor list — confirm callers always append it.
  PrepareDataForControlNode(host_data_source_actor, graph_compiler_info.control_node_parser_,
                            graph_compiler_info.origin_parameters_order_, input_tensors.back(), &host_tensors);

  // 4.Prepare the data of host tensor queue(non weighted parameters of graph).
  if (host_data_source_actor != nullptr) {
    const auto &host_tensor_queue = FetchHostQueue(actor_set->name_);
    MS_EXCEPTION_IF_NULL(host_tensor_queue);
    host_tensor_queue->Push(host_tensors);
  }
}
// Prepare device data before running a single-op actor set: value nodes and weights go into
// the device tensor store, other inputs into the host tensor queue. Unlike PrepareRun, the
// weight's front node is the backend node itself and there is no control-node or
// continuous-memory handling.
void GraphScheduler::PrepareRunOp(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info,
                                  const std::vector<std::vector<TensorPtr>> &input_tensors) {
  MS_EXCEPTION_IF_NULL(actor_set);
  std::vector<TensorPtr> host_tensors;
  std::string actor_name = actor_set->name_ + "_HostDSActor";
  // Null when this actor set has no host data source actor.
  const auto &host_data_source_actor = dynamic_cast<HostQueueDataSourceActor *>(FetchActor(actor_name));
  if (host_data_source_actor != nullptr) {
    // Pre-size so tensors can be placed by data-node position.
    host_tensors.resize(host_data_source_actor->data_nodes_.size());
  }

  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // 1.Prepare the data of device tensor store(value nodes of graph).
    for (const auto &value_node : graph->graph_value_nodes()) {
      if (AnfAlgo::OutputAddrExist(value_node, 0)) {
        PrepareDataForValueNode(value_node, device_context);
      }
    }

    // 2.Prepare the data of device tensor store(weights of graph), and fill host tensors for non weighted parameters.
    const auto &input_nodes = graph->input_nodes();
    // NOTE(review): assumes input_tensors[i] has one tensor per input node — confirm with callers.
    const auto &tensors = input_tensors[i];
    for (size_t j = 0; j < input_nodes.size(); ++j) {
      const auto &input_node = input_nodes[j];
      const auto &input_tensor = tensors[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (IsPersistentDeviceTensor(input_node)) {
        // Prepare the device data for weights.
        PrepareDataForWeightNode(input_node, input_node, input_tensor, device_context);
      } else {
        PrepareDataForInputData(host_data_source_actor, input_node, input_tensor, device_context, &host_tensors);
      }
    }
  }

  // 3.Prepare the data of host tensor queue(non weighted parameters of graph).
  if (host_data_source_actor != nullptr) {
    const auto &host_tensor_queue = FetchHostQueue(actor_set->name_);
    MS_EXCEPTION_IF_NULL(host_tensor_queue);
    host_tensor_queue->Push(host_tensors);
  }
}
  680. void GraphScheduler::PrepareDataForControlNode(HostQueueDataSourceActor *const host_data_source_actor,
  681. const ControlNodeParserPtr &control_node_parser,
  682. const std::vector<AnfNodePtr> &origin_parameters,
  683. const std::vector<TensorPtr> &tensors,
  684. std::vector<TensorPtr> *const host_tensors) {
  685. const auto &control_node_parameters = control_node_parser->GetControlNodeParameter();
  686. for (size_t j = 0; j < control_node_parameters.size(); ++j) {
  687. const auto &input_node = control_node_parameters[j];
  688. const auto &input_tensor = tensors[j];
  689. MS_EXCEPTION_IF_NULL(input_node);
  690. if (IsPersistentDeviceTensor(input_node)) {
  691. const auto &front_to_backend_parameters = control_node_parser->front_to_backend_parameters_;
  692. const auto &iter = front_to_backend_parameters.find(input_node);
  693. if (iter == front_to_backend_parameters.end()) {
  694. MS_LOG(EXCEPTION) << "Cannot find backend node for weight parameter:"
  695. << AnfAlgo::GetNodeDebugString(input_node);
  696. }
  697. const auto &node_with_context = iter->second;
  698. PrepareDataForControlWeightNode(node_with_context.first, input_node, input_tensor, node_with_context.second,
  699. control_node_parser->host_parameter_to_weights_);
  700. } else if (find(origin_parameters.begin(), origin_parameters.end(), input_node) != origin_parameters.end()) {
  701. const auto &iter = host_data_source_actor->data_node_position_map_.find(input_node);
  702. if (iter == host_data_source_actor->data_node_position_map_.end()) {
  703. MS_LOG(EXCEPTION) << "Cannot find node" << AnfAlgo::GetNodeDebugString(input_node) << " in data source actor";
  704. }
  705. const size_t pos = iter->second;
  706. const AnfNodePtr &backend_node = host_data_source_actor->data_nodes_[pos];
  707. (*host_tensors)[pos] = input_tensor;
  708. auto device_address = std::dynamic_pointer_cast<DeviceTensor>(input_tensor->device_address());
  709. if (device_address != nullptr) {
  710. AnfAlgo::SetOutputAddr(device_address, 0, backend_node.get());
  711. }
  712. }
  713. }
  714. for (const auto &value_node_with_context : control_node_parser->front_value_nodes_) {
  715. if (AnfAlgo::OutputAddrExist(value_node_with_context.first, 0)) {
  716. PrepareDataForValueNode(value_node_with_context.first->cast<ValueNodePtr>(), value_node_with_context.second);
  717. }
  718. }
  719. }
// Run the actor set. Step mode delegates to RunInStepMode; pipeline mode fires the source
// actors asynchronously and blocks until the shared result promise is fulfilled.
// Returns true when the run finished without error.
bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy,
                         const std::vector<TensorPtr> *input_tensors) {
  MS_EXCEPTION_IF_NULL(actor_set);
#if !defined(_WIN32) && !defined(_WIN64)
  // Convert Ctrl-C into a SIGTERM for the whole process while the graph is running.
  SignalGuard sg(IntHandler);
#endif
  if (strategy == GraphExecutionStrategy::kStep) {
    return RunInStepMode(actor_set, input_tensors);
  }

  // Construct OpContext. result[0] is the promise an actor fulfills when the run ends.
  OpContext<DeviceTensor> op_context;
  uuids::uuid sequential_num;
  std::vector<Promise<int>> result(1);
  op_context.sequential_num_ = &sequential_num;
  op_context.results_ = &result;

  // Trigger data source actor running.
  for (auto &data_source_actor : actor_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(data_source_actor);
    Async(data_source_actor->GetAID(), &DataSourceActor::FetchData, &op_context);
  }

  // Trigger no input kernel actor running.
  for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
    MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
    Async(no_input_kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context);
  }

  // Trigger output actor running when there are no data source actor and kernel actor.
  if ((actor_set->data_source_actors_.size() == 0) && (actor_set->kernel_actors_.size() == 0)) {
    MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
    Async(actor_set->output_actor_->GetAID(), &OutputActor::CollectLoopCount, actor_set->output_actor_->loop_count_,
          &op_context);
  }

  // Get the run result: block on the promise, re-raise any exception captured on an actor
  // thread, then report whether the future completed successfully.
  auto result_future = result[0].GetFuture();
  result_future.Wait();
  MsException::Instance().CheckException();
  return result_future.IsOK();
}
  757. ActorSet *GraphScheduler::Fetch(const ActorInfo &actor_info) const {
  758. auto iter = actors_.find(actor_info);
  759. if (iter != actors_.end()) {
  760. return iter->second.get();
  761. } else {
  762. MS_LOG(ERROR) << "Can't find the actors map of " << actor_info;
  763. return nullptr;
  764. }
  765. }
// Build every kind of actor for a new actor set from the graph compiler info. The host
// tensor queue is created here and cached by actor-set name so PrepareRun can feed it later.
// Note: copy actors and no-input kernel actors are filled in afterwards by Link().
ActorSetPtr GraphScheduler::Build(const GraphCompilerInfo &graph_compiler_info) {
  auto actor_set = std::make_shared<ActorSet>(graph_compiler_info.name_);
  MS_EXCEPTION_IF_NULL(actor_set);

  auto host_queue = std::make_shared<HostTensorQueue>();
  (void)actor_to_host_queue_.emplace(actor_set->name_, host_queue);
  actor_set->data_source_actors_ = BuildDataSourceActor(graph_compiler_info, host_queue);
  actor_set->kernel_actors_ = BuildKernelActor(graph_compiler_info);
  actor_set->loop_count_actor_ = BuildLoopCountActor(graph_compiler_info);
  actor_set->output_actor_ = BuildOutputActor(graph_compiler_info);
  actor_set->switch_actors_ = BuildSwitchActor(graph_compiler_info);
  actor_set->gather_actors_ = BuildGatherActor(graph_compiler_info);
  return actor_set;
}
// Cache the mapping from each front graph-output node (with index) to the actor and output
// index that produce it, used later when linking output-result arrows. Skipped in step mode.
void GraphScheduler::CacheGraphOutputToActor(const GraphCompilerInfo &graph_compiler_info) {
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
    return;
  }

  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
    for (const auto &output_with_index : outputs) {
      auto output_kernel = output_with_index.first;
      MS_EXCEPTION_IF_NULL(output_kernel);
      // Map the backend output back to the front node; outputs without one have no actor.
      auto origin_output_with_index = graph->GetFrontNodeWithIndexByGraphOutput(output_with_index);
      if (origin_output_with_index.first == nullptr) {
        MS_LOG(WARNING) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                        << " with index: " << output_with_index.second << " has no actor.";
        continue;
      }

      // Resolve which actor produces this output, depending on the kernel's transform type.
      auto actor_output_index = output_with_index.second;
      OpActor<DeviceTensor> *actor = nullptr;
      if (IsKernelActor(output_kernel, graph_compiler_info.strategy_)) {
        actor = FetchActor(output_kernel->fullname_with_scope());
      } else if (IsDeviceQueueDSActor(output_kernel, graph_compiler_info.strategy_)) {
        std::string actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
        actor = FetchActor(actor_name);
      } else if (IsHostQueueDSActor(output_kernel, graph, graph_compiler_info.origin_parameters_order_,
                                    graph_compiler_info.strategy_)) {
        actor = FetchActor(graph_compiler_info.name_ + "_HostDSActor");
        const auto &host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
        MS_EXCEPTION_IF_NULL(host_ds_actor);
        // Get the position of output kernel in the data source actor.
        actor_output_index = host_ds_actor->FetchNodePosition(output_kernel);
      } else if (IsPersistentDeviceTensor(output_kernel)) {
        // Weights come from the device tensor store, not from an actor.
        MS_LOG(INFO) << "The graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                     << " is device tensor store.";
        continue;
      } else {
        MS_LOG(INFO) << "Ignore the internal parameter node:" << output_kernel->DebugString();
        continue;
      }

      MS_EXCEPTION_IF_NULL(actor);
      MS_LOG(INFO) << "Cache the graph " << graph->graph_id() << " output node:" << output_kernel->fullname_with_scope()
                   << " with index: " << output_with_index.second << " to actor:" << actor->GetAID().Name()
                   << " with index:" << actor_output_index
                   << ", from front node:" << origin_output_with_index.first->fullname_with_scope()
                   << " with index: " << origin_output_with_index.second;
      (void)graph_output_to_actor_.emplace(origin_output_with_index,
                                           GraphOutputPair(dynamic_cast<AbstractActor *>(actor), actor_output_index));
    }
  }
}
// Link all arrows between the actors of the actor set by walking the execution order of each
// graph: data arrows for kernel inputs, control arrows for auto-monad / send-recv /
// communication nodes, plus the arrows of the loop count and output actors.
void GraphScheduler::Link(ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) {
  MS_EXCEPTION_IF_NULL(actor_set);
  std::vector<KernelActor *> auto_monad_actors;
  std::vector<CNodePtr> communication_nodes;
  // Primitives whose inputs carry auto-monad dependencies that become control arrows.
  const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> auto_monad_prims = {
    prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad};

  // Foreach the execution order to link the actors.
  for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
    const auto &graph = graph_compiler_info.graphs_[index];
    MS_EXCEPTION_IF_NULL(graph);
    auto execution_order = graph->execution_order();
    for (auto &kernel : execution_order) {
      // Collect communication nodes so their running order can be fixed up after the loop.
      if (AnfAlgo::IsCommunicationOp(kernel)) {
        (void)communication_nodes.emplace_back(kernel);
      }
      if (IsSkippedKernelActor(kernel) || (!IsKernelActor(kernel, graph_compiler_info.strategy_))) {
        continue;
      }
      const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
      MS_EXCEPTION_IF_NULL(kernel_actor);

      for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
        auto input_node = AnfAlgo::GetInputNode(kernel, i);
        // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
        if (AnfAlgo::IsOneOfPrimitiveCNode(input_node, auto_monad_prims)) {
          LinkControlArrowByAutoMonad(kernel_actor, input_node, graph);
        }
        if (HasAbstractMonad(input_node)) {
          (void)auto_monad_actors.emplace_back(kernel_actor);
          continue;  // No data arrow for monad input.
        }

        KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
        KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
        // The gather of linking data arrows of kernel by the different from kernel type.
        LinkDataArrow(kernel_actor, graph_compiler_info, graph, from_kernel_with_output_idx, to_kernel_with_input_idx);
      }
    }
    // Link the control arrows for allreduce kernel by the send/recv nodes in the kernel graph.
    LinkControlArrowBySendRecvNodes(graph);
  }

  // Link the control arrows by the communication nodes to ensure communication nodes running order.
  LinkControlArrowByCommunicationNode(communication_nodes, graph_compiler_info);

  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) {
    // Link the arrow by control node.
    LinkArrowByControlNode(graph_compiler_info, actor_set);
  }

  // Auto monad actor may modify the device tensor store.
  LinkDeviceTensorStoreForAutoMonadActor(auto_monad_actors);

  // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
  actor_set->no_input_kernel_actors_ = BuildNoInputKernelActor(actor_set, graph_compiler_info.strategy_);

  // Link the control arrows of loop count actor, which depends on the no input kernel actors.
  LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), actor_set,
                                    graph_compiler_info.control_node_parser_);

  // Link the output result arrows for output actors.
  LinkOutputResultArrowForOutputActor(actor_set->output_actor_.get(), graph_compiler_info);
}
// Build the data source actors of an actor set: at most one host queue data source actor
// shared by all graphs (fed by the host tensor queue), plus one device queue data source
// actor per graph that contains a device-queue kernel. Control-node parameters are also
// registered into the host queue data source actor.
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const GraphCompilerInfo &graph_compiler_info,
                                                                     const HostTensorQueuePtr &host_queue) {
  std::vector<DataSourceActorPtr> data_source_actors;
  HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  size_t data_node_position = 0;
  // Maps a front node to the position of its first backend node inside the host queue DS actor.
  std::unordered_map<AnfNodePtr, size_t> front_node_position_temp_map;

  for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
    const auto &graph = graph_compiler_info.graphs_[i];
    const auto &device_context = graph_compiler_info.device_contexts_[i];
    MS_EXCEPTION_IF_NULL(graph);
    // Build host queue data source actor.
    const std::vector<AnfNodePtr> &input_nodes = graph->input_nodes();
    for (size_t j = 0; j < input_nodes.size(); j++) {
      const auto &input_node = input_nodes[j];
      MS_EXCEPTION_IF_NULL(input_node);
      if (IsHostQueueDSActor(input_node, graph, graph_compiler_info.origin_parameters_order_,
                             graph_compiler_info.strategy_)) {
        // Lazily create the single host queue DS actor on the first matching input.
        if (host_queue_ds_actor == nullptr) {
          auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
          MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
          host_queue_ds_actor = std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr,
                                                                           nullptr, host_queue);
          InsertActor(host_queue_ds_actor.get());
          (void)data_source_actors.emplace_back(host_queue_ds_actor);
        }

        const auto &front_node = FetchFrontNodeByBackendNode(input_node, graph);
        // In the scenario where multiple backend nodes correspond to the same front node, only the first backend node
        // is saved in the host queue data source actor.
        if (front_node_position_temp_map.count(front_node) > 0) {
          (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node,
                                                                     front_node_position_temp_map[front_node]);
          continue;
        }
        (void)host_queue_ds_actor->data_nodes_.emplace_back(input_node);
        (void)host_queue_ds_actor->device_contexts_.emplace_back(device_context);
        (void)host_queue_ds_actor->data_node_position_map_.emplace(input_node, data_node_position);
        (void)front_node_position_temp_map.emplace(front_node, data_node_position);
        data_node_position++;
      }
    }

    // Build device queue data source actor: one per graph that has a device-queue kernel.
    const auto &execution_order = graph->execution_order();
    const auto &iter =
      std::find_if(execution_order.begin(), execution_order.end(), [&graph_compiler_info](const CNodePtr &node) {
        return IsDeviceQueueDSActor(node, graph_compiler_info.strategy_);
      });
    if (iter != execution_order.end()) {
      auto actor_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
      MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
      auto device_queue_ds_actor = std::make_shared<DeviceQueueDataSourceActor>(
        actor_name, 1, device_context, memory_manager_aid_, debug_aid_, recorder_aid_);
      MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
      InsertActor(device_queue_ds_actor.get());
      (void)data_source_actors.emplace_back(device_queue_ds_actor);
      device_queue_ds_actor->data_kernel_ = *iter;
      device_queue_ds_actor->kernel_info_ = dynamic_cast<device::KernelInfo *>((*iter)->kernel_info());
    }
  }

  const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;

  // Initialize the parameter in the control node, first get all the front parameters in the control node, then find
  // the corresponding backend parameter from the map, and insert it into the host data source actor
  std::vector<AnfNodePtr> control_node_parameters = graph_compiler_info.control_node_parser_->GetControlNodeParameter();
  for (const auto parameter : control_node_parameters) {
    if (IsPersistentDeviceTensor(parameter)) {
      continue;
    }
    auto backend_iter = front_to_backend_parameter.find(parameter);
    if (backend_iter == front_to_backend_parameter.end()) {
      MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(parameter);
    }
    // The host queue DS actor may not exist yet when only control nodes feed host data.
    if (host_queue_ds_actor == nullptr) {
      auto actor_name = graph_compiler_info.name_ + "_HostDSActor";
      MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
      host_queue_ds_actor =
        std::make_shared<HostQueueDataSourceActor>(actor_name, 1, memory_manager_aid_, nullptr, nullptr, host_queue);
      InsertActor(host_queue_ds_actor.get());
      (void)data_source_actors.emplace_back(host_queue_ds_actor);
    }

    const auto &backend_node = backend_iter->second.first;
    // Reuse the position when the backend node is already registered, otherwise append it.
    auto iter = find(host_queue_ds_actor->data_nodes_.begin(), host_queue_ds_actor->data_nodes_.end(), backend_node);
    if (iter != host_queue_ds_actor->data_nodes_.end()) {
      (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter,
                                                                 iter - host_queue_ds_actor->data_nodes_.begin());
    } else {
      (void)host_queue_ds_actor->data_node_position_map_.emplace(parameter, host_queue_ds_actor->data_nodes_.size());
      (void)host_queue_ds_actor->data_nodes_.emplace_back(backend_iter->second.first);
      (void)host_queue_ds_actor->device_contexts_.emplace_back(backend_iter->second.second);
    }
  }

  return data_source_actors;
}
  974. std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const GraphCompilerInfo &graph_compiler_info) {
  975. std::vector<KernelActorPtr> kernel_actors;
  976. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  977. const auto &graph = graph_compiler_info.graphs_[i];
  978. const auto &device_context = graph_compiler_info.device_contexts_[i];
  979. MS_EXCEPTION_IF_NULL(graph);
  980. auto execution_order = graph->execution_order();
  981. // Single op graph in step mode, kernel actor executes synchronously.
  982. bool is_single_op_graph = execution_order.size() == 1;
  983. GraphExecutionStrategy strategy = graph_compiler_info.strategy_;
  984. if (strategy == GraphExecutionStrategy::kStep) {
  985. strategy = (is_single_op_graph ? strategy : GraphExecutionStrategy::kPipeline);
  986. }
  987. for (auto &kernel : execution_order) {
  988. if (IsKernelActor(kernel, graph_compiler_info.strategy_) && (!IsSkippedKernelActor(kernel))) {
  989. auto kernel_actor = std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context,
  990. memory_manager_aid_, debug_aid_, recorder_aid_, strategy);
  991. MS_EXCEPTION_IF_NULL(kernel_actor);
  992. InsertActor(kernel_actor.get());
  993. (void)kernel_actors.emplace_back(kernel_actor);
  994. auto front_node = graph->GetFrontAnfByBackendAnf(kernel);
  995. if (front_node != nullptr) {
  996. front_node_to_actor_[front_node] = kernel_actor;
  997. }
  998. }
  999. }
  1000. }
  1001. return kernel_actors;
  1002. }
  1003. LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const GraphCompilerInfo &graph_compiler_info) {
  1004. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
  1005. return nullptr;
  1006. }
  1007. auto loop_count = ConfigManager::GetInstance().iter_num();
  1008. auto actor_name = graph_compiler_info.name_ + "_LoopCountActor";
  1009. auto loop_count_actor =
  1010. std::make_shared<LoopCountActor>(actor_name, loop_count, memory_manager_aid_, debug_aid_, recorder_aid_);
  1011. MS_LOG(INFO) << "Create loop count actor: " << actor_name;
  1012. MS_EXCEPTION_IF_NULL(loop_count_actor);
  1013. // Cache the nodes which need continuous memory.
  1014. for (size_t index = 0; index < graph_compiler_info.graphs_.size(); ++index) {
  1015. const auto &graph = graph_compiler_info.graphs_[index];
  1016. MS_EXCEPTION_IF_NULL(graph);
  1017. auto &execution_order = graph->execution_order();
  1018. for (auto &kernel : execution_order) {
  1019. if (!AnfAlgo::IsCommunicationOp(kernel)) {
  1020. continue;
  1021. }
  1022. auto key = std::make_pair(kernel, graph_compiler_info.device_contexts_[index]);
  1023. auto value = std::make_pair(false, false);
  1024. if (AnfAlgo::GetInputTensorNum(kernel) > 1) {
  1025. value.first = true;
  1026. }
  1027. if (AnfAlgo::GetOutputTensorNum(kernel) > 1) {
  1028. value.second = true;
  1029. }
  1030. if ((value.first == true) || (value.second == true)) {
  1031. loop_count_actor->continuous_memory_nodes_[key] = value;
  1032. }
  1033. }
  1034. }
  1035. InsertActor(loop_count_actor.get());
  1036. return loop_count_actor;
  1037. }
  1038. OutputActorPtr GraphScheduler::BuildOutputActor(const GraphCompilerInfo &graph_compiler_info) {
  1039. if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
  1040. return nullptr;
  1041. }
  1042. auto loop_count = ConfigManager::GetInstance().iter_num();
  1043. auto actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
  1044. bool need_loop_count = (graph_compiler_info.strategy_ == GraphExecutionStrategy::kPipeline) ? true : false;
  1045. auto output_actor =
  1046. std::make_shared<OutputActor>(actor_name, loop_count, graph_compiler_info.outputs_num_, need_loop_count);
  1047. MS_LOG(INFO) << "Create output actor: " << actor_name;
  1048. MS_EXCEPTION_IF_NULL(output_actor);
  1049. InsertActor(output_actor.get());
  1050. return output_actor;
  1051. }
  1052. std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const ActorSet *actor_set,
  1053. GraphExecutionStrategy strategy) {
  1054. MS_EXCEPTION_IF_NULL(actor_set);
  1055. std::vector<KernelActorPtr> no_input_kernel_actors;
  1056. for (auto &kernel_actor : actor_set->kernel_actors_) {
  1057. MS_EXCEPTION_IF_NULL(kernel_actor);
  1058. // Framework will trigger kernel actor running in the step execution strategy.
  1059. if (strategy == GraphExecutionStrategy::kStep && IsSingleOpActorSet(actor_set)) {
  1060. kernel_actor->input_controls_num_++;
  1061. continue;
  1062. }
  1063. if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
  1064. // Check whether the kernel actor belongs to the root graph.
  1065. // In general, all no input nodes belong to the root funcgraph, and the corresponding gather actor should be
  1066. // empty. In control flow, the control arrow of the no input node in the sub funcgraph should be sent by the
  1067. // gather actor and should not be placed in the no input list.
  1068. const auto &graph = kernel_actor->kernel_->func_graph();
  1069. if (graph != nullptr) {
  1070. const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
  1071. const auto func_graph = kernel_graph->GetFuncGraph();
  1072. if (func_graph != nullptr && FetchActor(func_graph->ToString()) != nullptr) {
  1073. continue;
  1074. }
  1075. }
  1076. (void)no_input_kernel_actors.emplace_back(kernel_actor);
  1077. }
  1078. }
  1079. return no_input_kernel_actors;
  1080. }
  1081. std::vector<SwitchActorPtr> GraphScheduler::BuildSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
  1082. std::vector<SwitchActorPtr> switch_actors;
  1083. std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
  1084. for (const auto &pair : front_node_to_actor_) {
  1085. front_to_backend_kernel[pair.first] = pair.second->kernel_;
  1086. }
  1087. // Build switch actor by switch node and switchlayer node.
  1088. for (const auto &control_node : graph_compiler_info.control_nodes_) {
  1089. if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
  1090. AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
  1091. const auto func_graph = control_node->func_graph();
  1092. const auto branch_id = graph_compiler_info.control_node_parser_->GetBranchIDByFuncGraph(func_graph);
  1093. const auto &actor_name = control_node->DebugString();
  1094. auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
  1095. control_node->cast<CNodePtr>(), branch_id, false);
  1096. switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
  1097. // Fetch all the input nodes of switch actor.
  1098. switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
  1099. InsertActor(switch_actor.get());
  1100. (void)switch_actors.emplace_back(switch_actor);
  1101. }
  1102. }
  1103. // Build switch actor by return node.
  1104. const auto func_graphs_to_call_num = graph_compiler_info.control_node_parser_->func_graph_to_call_num_;
  1105. for (const auto &func_graph_to_call_num : func_graphs_to_call_num) {
  1106. const auto &return_node = func_graph_to_call_num.first->get_return();
  1107. MS_EXCEPTION_IF_NULL(return_node);
  1108. const auto &actor_name = return_node->DebugString();
  1109. auto switch_actor = std::make_shared<SwitchActor>(actor_name, graph_compiler_info.device_contexts_[0],
  1110. return_node->cast<CNodePtr>(), kInvalidBranchID, true);
  1111. switch_actor->ParseInput(graph_compiler_info.control_node_parser_);
  1112. // Fetch all the input nodes of switch actor.
  1113. switch_actor->FetchInputNode(graph_compiler_info.control_node_parser_);
  1114. InsertActor(switch_actor.get());
  1115. (void)switch_actors.emplace_back(switch_actor);
  1116. }
  1117. return switch_actors;
  1118. }
  1119. std::vector<GatherActorPtr> GraphScheduler::BuildGatherActor(const GraphCompilerInfo &graph_compiler_info) {
  1120. std::vector<GatherActorPtr> gather_actors;
  1121. const auto &loop_count_actor_name = graph_compiler_info.name_ + "_LoopCountActor";
  1122. const auto &loop_count_actor = FetchActor(loop_count_actor_name);
  1123. if (loop_count_actor == nullptr) {
  1124. return gather_actors;
  1125. }
  1126. const auto &output_actor_name = graph_compiler_info.name_ + "_" + "OutputActor";
  1127. const auto &output_actor = FetchActor(output_actor_name);
  1128. MS_EXCEPTION_IF_NULL(output_actor);
  1129. const auto parser = graph_compiler_info.control_node_parser_;
  1130. bool is_main_return = true;
  1131. // Each funcgraph has a return node, get the funcgraph from the return node, and create a gather actor.
  1132. std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_backend_kernel;
  1133. for (const auto &pair : front_node_to_actor_) {
  1134. front_to_backend_kernel[pair.first] = pair.second->kernel_;
  1135. }
  1136. for (const auto &control_node : graph_compiler_info.control_nodes_) {
  1137. const auto &func_graph = control_node->func_graph();
  1138. const auto &cnode = control_node->cast<CNodePtr>();
  1139. const auto &inputs = cnode->inputs();
  1140. const auto &return_node = func_graph->get_return();
  1141. if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimReturn)) {
  1142. // Root funcgraph does not need to create a gather actor.
  1143. if (is_main_return) {
  1144. is_main_return = false;
  1145. continue;
  1146. }
  1147. if (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimPartial)) {
  1148. continue;
  1149. }
  1150. auto actor_name = func_graph->ToString();
  1151. std::vector<KernelWithIndex> parameters;
  1152. for (const auto &parameter : func_graph->get_inputs()) {
  1153. if (HasAbstractMonad(parameter) || HasAbstractRef(parameter)) {
  1154. continue;
  1155. }
  1156. (void)parameters.emplace_back(parameter, 0);
  1157. }
  1158. const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
  1159. const auto &output_switch_actor = FetchActor(return_node->DebugString());
  1160. MS_EXCEPTION_IF_NULL(output_switch_actor);
  1161. const auto &output_switch_aid = output_switch_actor->GetAID();
  1162. auto gather_actor =
  1163. std::make_shared<GatherActor>(actor_name, parameters, true, output_switch_aid, AID(), branch_id);
  1164. gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
  1165. InsertActor(gather_actor.get());
  1166. (void)gather_actors.emplace_back(gather_actor);
  1167. }
  1168. }
  1169. // Create gather actor for call node which input0 of call node is a funcgraph.
  1170. for (const auto &control_node : graph_compiler_info.control_nodes_) {
  1171. const auto &cnode = control_node->cast<CNodePtr>();
  1172. const auto &inputs = cnode->inputs();
  1173. if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
  1174. // Collect the parameters.
  1175. std::vector<KernelWithIndex> parameters;
  1176. for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
  1177. if (HasAbstractMonad(inputs[i]) || (inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]))) {
  1178. continue;
  1179. }
  1180. (void)parameters.emplace_back(inputs[i], 0);
  1181. }
  1182. auto func_graph = control_node->func_graph();
  1183. auto actor_name = control_node->DebugString();
  1184. const auto branch_id = parser->GetBranchIDByFuncGraph(func_graph);
  1185. const auto &to_func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
  1186. const auto &to_actor = FetchActor(to_func_graph->ToString());
  1187. auto gather_actor =
  1188. std::make_shared<GatherActor>(actor_name, parameters, false, AID(), to_actor->GetAID(), branch_id);
  1189. gather_actor->FetchBackendInputNode(func_graph, graph_compiler_info.control_node_parser_);
  1190. InsertActor(gather_actor.get());
  1191. (void)gather_actors.emplace_back(gather_actor);
  1192. }
  1193. }
  1194. // Create gather actor for kernel graph which has a call input.
  1195. const auto &graph_with_device_contexts = graph_compiler_info.control_node_parser_->call_input_kernel_graphs_;
  1196. for (const auto &graph_with_device_context : graph_with_device_contexts) {
  1197. const auto &graph = graph_with_device_context.first;
  1198. const auto &parameters = FetchParameterbyKernelGraph(graph);
  1199. auto actor_name = graph->ToString();
  1200. auto gather_actor = std::make_shared<GatherActor>(actor_name, parameters, false, AID(), AID(), kInvalidBranchID);
  1201. InsertActor(gather_actor.get());
  1202. (void)gather_actors.emplace_back(gather_actor);
  1203. }
  1204. return gather_actors;
  1205. }
void GraphScheduler::LinkDataArrow(KernelActor *const to_actor, const GraphCompilerInfo &graph_compiler_info,
                                   const KernelGraphPtr &graph, const KernelWithIndex &from_kernel_with_output_idx,
                                   const KernelWithIndex &to_kernel_with_input_idx) {
  // Dispatch the linking of one data arrow (producer output -> consumer input) to the
  // proper routine, depending on the producer kind: a gather actor (control flow),
  // a device tensor store key, or a regular actor resolved via kKernelTypeToLinkFunc.
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(graph);
  auto from_kernel = from_kernel_with_output_idx.first;
  if (from_kernel->isa<Parameter>() && graph_compiler_info.control_node_parser_->IsCallInputKernelGraph(graph)) {
    // The graph has a call input: its parameters are fed by the graph's gather actor.
    const auto &kernel_with_index = GetFrontNodeByKernelGraph(from_kernel, graph);
    const auto &real_front_node_with_index =
      AnfAlgo::VisitKernelWithReturnType(kernel_with_index.first, kernel_with_index.second);
    if (HasAbstractRef(real_front_node_with_index.first)) {
      // Ref parameters are read from the device tensor store instead of a data arrow.
      (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second,
                                                             real_front_node_with_index.first);
      return;
    }
    // When there is a call input in the kernel graph, all the inputs of the kernel graph needs to be sent by gather.
    const auto actor_name = graph->ToString();
    auto actor = FetchActor(actor_name);
    MS_EXCEPTION_IF_NULL(actor);
    LinkDataArrowForGatherActor(dynamic_cast<GatherActor *>(actor), to_actor, real_front_node_with_index,
                                to_kernel_with_input_idx);
    return;
  }
  auto front_node = GetFrontNodeByBackendNode(from_kernel);
  if (front_node != nullptr && IsGatherActor(front_node, actor_name_to_actor_)) {
    // Link the data arrows of gather actor.
    auto func_graph = GetFuncgraphByBackendNode(from_kernel);
    if (func_graph == nullptr) {
      MS_LOG(EXCEPTION) << "Cannot find funcgraph of node:" << AnfAlgo::GetNodeDebugString(from_kernel);
    }
    auto actor_name = func_graph->ToString();
    const auto &from_actor = dynamic_cast<GatherActor *>(FetchActor(actor_name));
    if (HasAbstractRef(from_kernel)) {
      // Ref parameters are read from the device tensor store instead of a data arrow.
      (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, front_node);
      return;
    }
    LinkDataArrowForGatherActor(from_actor, to_actor, {front_node, 0}, to_kernel_with_input_idx);
    return;
  }
  // Regular case: classify the producer (device tensor store, internal parameter,
  // data source actor, kernel actor) and dispatch to the matching member function.
  auto kernel_type = KernelTransformType::kUnknown;
  std::string kernel_name = "";
  FetchKernelTransformTypeAndName(from_kernel, graph, graph_compiler_info, &kernel_type, &kernel_name);
  auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
  if (kKernelTypeToLinkFunc.count(kernel_type) > 0) {
    (this->*kKernelTypeToLinkFunc[kernel_type])(from_actor, to_actor, from_kernel_with_output_idx,
                                                to_kernel_with_input_idx, graph);
  }
}
  1254. void GraphScheduler::LinkDataArrowForDeviceTensorStore(AbstractActor *const, KernelActor *const to_actor,
  1255. const KernelWithIndex &from_kernel_with_output_idx,
  1256. const KernelWithIndex &to_kernel_with_input_idx,
  1257. const KernelGraphPtr &graph) {
  1258. MS_EXCEPTION_IF_NULL(to_actor);
  1259. MS_EXCEPTION_IF_NULL(graph);
  1260. auto from_kernel = from_kernel_with_output_idx.first;
  1261. MS_EXCEPTION_IF_NULL(from_kernel);
  1262. auto device_tensor_store_key = FetchFrontNodeByBackendNode(from_kernel, graph);
  1263. (void)to_actor->device_tensor_store_keys_.emplace_back(to_kernel_with_input_idx.second, device_tensor_store_key);
  1264. }
void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, KernelActor *to_actor,
                                                       const KernelWithIndex &from_kernel_with_output_idx,
                                                       const KernelWithIndex &to_kernel_with_input_idx,
                                                       const KernelGraphPtr &graph) {
  // An internal parameter is a graph input standing in for the output of another graph.
  // Resolve the front node behind the parameter, locate the real producer actor, and
  // re-dispatch the link through kKernelTypeToLinkFunc.
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(graph);
  auto internal_parameter = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(internal_parameter);
  // Parameter ---> front node.
  auto front_output_with_index = graph->GetFrontNodeByInternalParameter(internal_parameter);
  auto front_output_node = front_output_with_index.first;
  MS_EXCEPTION_IF_NULL(front_output_node);
  if (IsSwitchActor(front_output_node)) {
    // The producer is a switch actor (control flow): link from its output 0 directly.
    auto switch_actor = dynamic_cast<SwitchActor *>(FetchActor(front_output_node->DebugString()));
    MS_EXCEPTION_IF_NULL(switch_actor);
    LinkDataArrowForSwitchActor(switch_actor, 0, to_actor, to_kernel_with_input_idx.second);
    to_actor->input_datas_num_++;
    return;
  }
  auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  AbstractActor *real_from_actor = nullptr;
  KernelTransformType kernel_type;
  if (IsPersistentDeviceTensor(front_output_node)) {
    kernel_type = KernelTransformType::kDeviceTensorStore;
  } else {
    // front node ---> actor.
    if (graph_output_to_actor_.count(front_output_with_index) == 0) {
      MS_LOG(EXCEPTION) << "Can't find actor by front node:" << AnfAlgo::GetNodeDebugString(front_output_node)
                        << ", internal parameter:" << AnfAlgo::GetNodeDebugString(internal_parameter);
    }
    auto actor_pair = graph_output_to_actor_[front_output_with_index];
    MS_EXCEPTION_IF_NULL(actor_pair.first);
    MS_LOG(INFO) << "Graph " << graph->graph_id() << " internal parameter:" << internal_parameter->DebugString()
                 << ", corresponding front node:" << front_output_node->fullname_with_scope()
                 << " with index:" << front_output_with_index.second
                 << ", from actor:" << actor_pair.first->GetAID().Name() << " with index:" << actor_pair.second
                 << ", to actor:" << to_actor->GetAID().Name() << " with index:" << to_kernel_with_input_idx.second;
    real_from_actor = actor_pair.first;
    // The kernel is intentionally left null: the concrete link routine recovers the
    // real node from the actor, only the output index is carried through here.
    real_from_kernel_with_output_idx = KernelWithIndex(nullptr, actor_pair.second);
    kernel_type = actor_pair.first->type_;
  }
  if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
    MS_LOG(EXCEPTION) << "Invalid internal parameter:" << internal_parameter->DebugString() << ", type:" << kernel_type;
  }
  (this->*kKernelTypeToLinkFunc[kernel_type])(real_from_actor, to_actor, real_from_kernel_with_output_idx,
                                              to_kernel_with_input_idx, graph);
}
void GraphScheduler::LinkDataArrowForBaseActor(AbstractActor *const from_actor, KernelActor *const to_actor,
                                               const KernelWithIndex &from_kernel_with_output_idx,
                                               const KernelWithIndex &to_kernel_with_input_idx) {
  // Common tail of data-arrow linking: insert a copy actor when the two actors live on
  // different device contexts, otherwise link a direct data arrow, bump the consumer's
  // input counter and the producer output's reference count.
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto from_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(from_kernel);
  auto from_output_index = from_kernel_with_output_idx.second;
  auto to_input_index = to_kernel_with_input_idx.second;
  // Get the position of from kernel in the data source actor.
  auto position = from_actor->FetchNodePosition(from_kernel);
  if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
    MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
  }
  if (IsNeedInsertCopyActor(from_actor->device_contexts_[position], to_actor->device_contexts_[0])) {
    // Cross-device link: route the data through a copy actor.
    LinkDataArrowForCopyActor(from_actor, to_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
  } else {
    auto to_aid = to_actor->GetAID();
    auto op_arrow = std::make_shared<DataArrow>(from_output_index, to_aid, to_input_index);
    // If the from actor has the multi nodes, then use the real output position.
    if (position != 0) {
      op_arrow->from_output_index_ = SizeToInt(position);
    }
    (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
    to_actor->input_datas_num_++;
    (void)to_actor->input_data_arrow_aids_.emplace_back(from_actor->GetAID());
    // Update the reference count of device tensor.
    UpdateRefCount(from_kernel, from_output_index);
  }
}
  1342. void GraphScheduler::LinkDataArrowForDeviceDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
  1343. const KernelWithIndex &from_kernel_with_output_idx,
  1344. const KernelWithIndex &to_kernel_with_input_idx,
  1345. const KernelGraphPtr &) {
  1346. auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  1347. if (real_from_kernel_with_output_idx.first == nullptr) {
  1348. auto device_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(from_actor);
  1349. MS_EXCEPTION_IF_NULL(device_ds_actor);
  1350. real_from_kernel_with_output_idx.first = device_ds_actor->data_kernel_;
  1351. }
  1352. LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
  1353. }
  1354. void GraphScheduler::LinkDataArrowForHostDSActor(AbstractActor *const from_actor, KernelActor *const to_actor,
  1355. const KernelWithIndex &from_kernel_with_output_idx,
  1356. const KernelWithIndex &to_kernel_with_input_idx,
  1357. const KernelGraphPtr &) {
  1358. auto host_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
  1359. MS_EXCEPTION_IF_NULL(host_ds_actor);
  1360. KernelWithIndex real_from_kernel_with_output_idx;
  1361. if (from_kernel_with_output_idx.first != nullptr) {
  1362. // Get the position of from kernel in the data source actor.
  1363. auto position = host_ds_actor->FetchNodePosition(from_kernel_with_output_idx.first);
  1364. real_from_kernel_with_output_idx.first = host_ds_actor->data_nodes_[position];
  1365. real_from_kernel_with_output_idx.second = from_kernel_with_output_idx.second;
  1366. } else {
  1367. real_from_kernel_with_output_idx.first = host_ds_actor->data_nodes_[from_kernel_with_output_idx.second];
  1368. real_from_kernel_with_output_idx.second = 0;
  1369. }
  1370. LinkDataArrowForBaseActor(from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
  1371. }
  1372. void GraphScheduler::LinkDataArrowForKernelActor(AbstractActor *const from_actor, KernelActor *const to_actor,
  1373. const KernelWithIndex &from_kernel_with_output_idx,
  1374. const KernelWithIndex &to_kernel_with_input_idx,
  1375. const KernelGraphPtr &) {
  1376. auto real_from_actor = from_actor;
  1377. auto real_from_kernel_with_output_idx = from_kernel_with_output_idx;
  1378. auto from_kernel = from_kernel_with_output_idx.first;
  1379. if (from_kernel == nullptr) {
  1380. auto kernel_actor = dynamic_cast<KernelActor *>(from_actor);
  1381. MS_EXCEPTION_IF_NULL(kernel_actor);
  1382. from_kernel = kernel_actor->kernel_;
  1383. real_from_kernel_with_output_idx.first = kernel_actor->kernel_;
  1384. }
  1385. MS_EXCEPTION_IF_NULL(from_kernel);
  1386. if (IsSkippedKernelActor(from_kernel)) {
  1387. real_from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(from_kernel, 0);
  1388. MS_EXCEPTION_IF_NULL(real_from_kernel_with_output_idx.first);
  1389. LinkControlArrowBySkippedNode(to_actor, from_kernel);
  1390. // Update the from kernel info by the real node info.
  1391. MS_LOG(INFO) << "Link data arrow for inplace node, aggregate node: "
  1392. << to_kernel_with_input_idx.first->fullname_with_scope()
  1393. << ", aggregate input index: " << to_kernel_with_input_idx.second
  1394. << ", skip node: " << from_kernel->fullname_with_scope()
  1395. << ", real node: " << real_from_kernel_with_output_idx.first->fullname_with_scope();
  1396. real_from_actor =
  1397. dynamic_cast<AbstractActor *>(FetchActor(real_from_kernel_with_output_idx.first->fullname_with_scope()));
  1398. MS_EXCEPTION_IF_NULL(real_from_actor);
  1399. }
  1400. LinkDataArrowForBaseActor(real_from_actor, to_actor, real_from_kernel_with_output_idx, to_kernel_with_input_idx);
  1401. }
void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, KernelActor *const to_actor,
                                               const KernelWithIndex &from_kernel_with_output_idx,
                                               const KernelWithIndex &to_kernel_with_input_idx) {
  // Link a data arrow across device contexts by routing it through a copy actor:
  //   from actor --> copy actor --> to actor.
  // The copy actor is keyed by (from actor, from kernel, output index); several
  // consumers of the same producer output share one copy actor, and only the
  // copy->to link is added for subsequent consumers.
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  auto from_kernel = from_kernel_with_output_idx.first;
  MS_EXCEPTION_IF_NULL(from_kernel);
  auto from_output_index = from_kernel_with_output_idx.second;
  auto to_input_index = to_kernel_with_input_idx.second;
  std::string name = "copy_from:" + from_actor->GetAID().Name() + "_node:" + from_kernel->fullname_with_scope() +
                     "_output_index:" + std::to_string(from_output_index);
  CopyActor *copy_actor = dynamic_cast<CopyActor *>(FetchActor(name));
  // Link between from actor and copy actor.
  if (copy_actor == nullptr) {
    // Create the copy actor.
    auto copy_actor_shared_ptr = std::make_shared<CopyActor>(name, memory_manager_aid_);
    (void)copy_actors_.emplace_back(copy_actor_shared_ptr);
    copy_actor = copy_actor_shared_ptr.get();
    MS_EXCEPTION_IF_NULL(copy_actor);
    InsertActor(copy_actor);
    // Get the position of from kernel in the data source actor.
    auto position = from_actor->FetchNodePosition(from_kernel);
    if ((from_actor->device_contexts_.size() <= position) || (to_actor->device_contexts_.size() <= 0)) {
      MS_LOG(EXCEPTION) << "The device contexts size is wrong.";
    }
    auto from_device_context = from_actor->device_contexts_[position];
    auto to_device_context = to_actor->device_contexts_[0];
    auto from_device_tensor = AnfAlgo::GetMutableOutputAddr(from_kernel, from_output_index, false);
    MS_EXCEPTION_IF_NULL(from_device_context);
    MS_EXCEPTION_IF_NULL(to_device_context);
    MS_EXCEPTION_IF_NULL(from_device_tensor);
    auto op_arrow_to_copy = std::make_shared<DataArrow>(from_output_index, copy_actor->GetAID(), 0);
    // If the from actor has the multi nodes, then use the real output position.
    if (position != 0) {
      op_arrow_to_copy->from_output_index_ = SizeToInt(position);
    }
    // Link.
    (void)from_actor->output_data_arrows_.emplace_back(op_arrow_to_copy);
    copy_actor->input_datas_num_++;
    // Set the member of the copy actor.
    auto to_kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first);
    MS_EXCEPTION_IF_NULL(to_kernel_mod);
    auto input_sizes = to_kernel_mod->GetInputSizeList();
    if (to_input_index >= input_sizes.size()) {
      MS_LOG(EXCEPTION) << "To input index(" << to_input_index << ") is out of size: " << input_sizes.size();
    }
    // The copy output lives on the consumer's device: the size comes from the consumer's
    // kernel mod, the format/type from the producer's device tensor.
    copy_actor->output_ = to_device_context->CreateDeviceAddress(
      nullptr, input_sizes[to_input_index], from_device_tensor->format(), from_device_tensor->type_id());
    // By convention device_contexts_[0] is the from context, device_contexts_[1] the to context.
    (void)copy_actor->device_contexts_.emplace_back(from_device_context);
    (void)copy_actor->device_contexts_.emplace_back(to_device_context);
    // Update the reference count of device tensor.
    UpdateRefCount(from_device_tensor.get());
  }
  // If the copy actor already exists, only need link between copy actor and to actor.
  auto op_arrow_from_copy = std::make_shared<DataArrow>(0, to_actor->GetAID(), to_input_index);
  (void)copy_actor->output_data_arrows_.emplace_back(op_arrow_from_copy);
  to_actor->input_datas_num_++;
  // Every additional consumer holds one more reference to the copied output.
  UpdateRefCount(copy_actor->output_.get());
}
void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
                                                 const KernelGraphPtr &graph) {
  // Link the control arrows that express monad (side-effect) ordering: the real kernels
  // hidden behind Depend/UpdateState/Load/MakeTuple chains feeding from_node must run
  // before to_actor. Recurses through nested monad/make-tuple nodes.
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(from_node);
  // Find the real input node, include the monad node and make tuple node.
  const std::vector<PrimitivePtr> return_types = {prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad,
                                                  prim::kPrimMakeTuple};
  const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, false, return_types);
  MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
  auto input_anfnode = input_kernel_with_output_idx.first;
  CNodePtr input_cnode = nullptr;
  if (input_anfnode->isa<CNode>()) {
    input_cnode = input_anfnode->cast<CNodePtr>();
  }
  // Make tuple node needs to be expanded.
  if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimMakeTuple)) {
    MS_EXCEPTION_IF_NULL(input_cnode);
    for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
      LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), graph);
    }
    return;
  }
  const std::unordered_set<PrimitivePtr, PrimitiveHasher, PrimitiveEqual> recursion_prims = {
    prim::kPrimDepend, prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
  // Get the real depend input by monad node which needs to link the control arrow.
  std::vector<AnfNodePtr> real_depend_inputs;
  if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimDepend) ||
      AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimLoad)) {
    MS_EXCEPTION_IF_NULL(input_cnode);
    real_depend_inputs.push_back(input_cnode->input(kDependAttachNodeIndex));
    // The real input may be this scene: depend/load --> load/depend, so need add the control arrow for real input
    // node in this scene.
    if (AnfAlgo::IsOneOfPrimitiveCNode(input_cnode->input(kRealInputIndexInDepend), recursion_prims)) {
      real_depend_inputs.push_back(input_cnode->input(kRealInputIndexInDepend));
    }
  } else if (AnfAlgo::CheckPrimitiveType(input_anfnode, prim::kPrimUpdateState)) {
    MS_EXCEPTION_IF_NULL(input_cnode);
    // All real inputs of UpdateState (after the monad input) are dependencies.
    for (size_t i = kUpdateStateRealInput; i < input_cnode->inputs().size(); ++i) {
      real_depend_inputs.push_back(input_cnode->input(i));
    }
  } else {
    real_depend_inputs.push_back(input_anfnode);
  }
  for (const auto &real_depend_input : real_depend_inputs) {
    auto real_depend_input_with_idx = AnfAlgo::VisitKernelWithReturnType(real_depend_input, 0, false, return_types);
    auto real_depend_kernel = real_depend_input_with_idx.first;
    // The monad node and make tuple node need recursion.
    if (AnfAlgo::IsOneOfPrimitiveCNode(real_depend_kernel, recursion_prims)) {
      LinkControlArrowByAutoMonad(to_actor, real_depend_kernel, graph);
      continue;
    }
    KernelActor *from_actor = nullptr;
    if (IsKernelActor(real_depend_kernel)) {
      from_actor = dynamic_cast<KernelActor *>(FetchActor(real_depend_kernel->fullname_with_scope()));
    } else if (IsInternalParameter(real_depend_kernel, graph)) {
      // The dependency crosses a graph boundary: resolve the producing kernel actor
      // through the front node of the internal parameter.
      auto front_output_with_index = graph->GetFrontNodeByInternalParameter(real_depend_kernel);
      MS_EXCEPTION_IF_NULL(front_output_with_index.first);
      if (IsKernelActor(front_output_with_index.first)) {
        if (graph_output_to_actor_.count(front_output_with_index) == 0) {
          MS_LOG(EXCEPTION) << "Can't find actor by front node:" << front_output_with_index.first->DebugString();
        }
        from_actor = dynamic_cast<KernelActor *>(graph_output_to_actor_[front_output_with_index].first);
      }
    }
    // Producers that are not kernel actors (e.g. parameters) need no control arrow.
    if (from_actor == nullptr) {
      continue;
    }
    MS_LOG(INFO) << "Link control arrow by auto monad, from actor: " << from_actor->GetAID().Name()
                 << ", to actor: " << to_actor->GetAID().Name();
    (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
    to_actor->input_controls_num_++;
  }
}
  1534. void GraphScheduler::LinkControlArrowBySkippedNode(KernelActor *to_actor, const AnfNodePtr &skipped_node) {
  1535. MS_EXCEPTION_IF_NULL(to_actor);
  1536. MS_EXCEPTION_IF_NULL(skipped_node);
  1537. auto to_aid = to_actor->GetAID();
  1538. // Link the control arrow from all the inputs of skipped node to the user of skipped node.
  1539. auto input_num = AnfAlgo::GetInputTensorNum(skipped_node);
  1540. for (size_t i = 0; i < input_num; ++i) {
  1541. auto kernel_with_index = AnfAlgo::GetPrevNodeOutput(skipped_node, i, false);
  1542. MS_EXCEPTION_IF_NULL(kernel_with_index.first);
  1543. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(kernel_with_index.first->fullname_with_scope()));
  1544. MS_EXCEPTION_IF_NULL(from_actor);
  1545. MS_LOG(INFO) << "Link control arrow by skipped node: " << skipped_node->fullname_with_scope()
  1546. << ", from actor: " << from_actor->GetAID().Name() << ", to actor: " << to_actor->GetAID().Name();
  1547. (void)from_actor->output_control_arrows_.emplace_back(to_aid);
  1548. to_actor->input_controls_num_++;
  1549. }
  1550. }
  1551. void GraphScheduler::LinkControlArrowBySendRecvNodes(const KernelGraphPtr &graph) {
  1552. MS_EXCEPTION_IF_NULL(graph);
  1553. for (auto &from_iter : graph->allreduce_from_send_recv_pairs()) {
  1554. auto to_allreduce_node = from_iter.first;
  1555. MS_LOG(INFO) << "Link control arrow for to_allreduce_node: " << to_allreduce_node->fullname_with_scope();
  1556. auto from_send_node = from_iter.second.first;
  1557. auto from_recv_node = from_iter.second.second;
  1558. auto to_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(to_allreduce_node->fullname_with_scope()));
  1559. auto from_send_actor = dynamic_cast<KernelActor *>(FetchActor(from_send_node->fullname_with_scope()));
  1560. auto from_recv_actor = dynamic_cast<KernelActor *>(FetchActor(from_recv_node->fullname_with_scope()));
  1561. // inputs of to_allreduce_actor --> from_send_actor
  1562. for (auto &input_aid : to_allreduce_actor->input_data_arrow_aids_) {
  1563. auto input_actor = dynamic_cast<KernelActor *>(FetchActor(input_aid.Name()));
  1564. if (input_actor != nullptr) {
  1565. (void)input_actor->output_control_arrows_.emplace_back(from_send_actor->GetAID());
  1566. from_send_actor->input_controls_num_++;
  1567. }
  1568. }
  1569. // from_send_actor --> from_recv_actor
  1570. (void)from_send_actor->output_control_arrows_.emplace_back(from_recv_actor->GetAID());
  1571. from_recv_actor->input_controls_num_++;
  1572. // from_recv_actor --> to_allreduce_actor
  1573. (void)from_recv_actor->output_control_arrows_.emplace_back(to_allreduce_actor->GetAID());
  1574. to_allreduce_actor->input_controls_num_++;
  1575. }
  1576. for (auto &to_iter : graph->allreduce_to_send_recv_pairs()) {
  1577. auto from_allreduce_node = to_iter.first;
  1578. MS_LOG(INFO) << "Link control arrow for from_allreduce_node: " << from_allreduce_node->fullname_with_scope();
  1579. auto to_send_node = to_iter.second.first;
  1580. auto to_recv_node = to_iter.second.second;
  1581. auto from_allreduce_actor = dynamic_cast<KernelActor *>(FetchActor(from_allreduce_node->fullname_with_scope()));
  1582. auto to_send_actor = dynamic_cast<KernelActor *>(FetchActor(to_send_node->fullname_with_scope()));
  1583. auto to_recv_actor = dynamic_cast<KernelActor *>(FetchActor(to_recv_node->fullname_with_scope()));
  1584. // from_allreduce_actor --> to_send_actor
  1585. (void)from_allreduce_actor->output_control_arrows_.emplace_back(to_send_actor->GetAID());
  1586. to_send_actor->input_controls_num_++;
  1587. // to_send_actor --> to_recv_actor
  1588. (void)to_send_actor->output_control_arrows_.emplace_back(to_recv_actor->GetAID());
  1589. to_recv_actor->input_controls_num_++;
  1590. // to_recv_actor --> outputs of from_allreduce_actor
  1591. for (auto &output_data_arrow : from_allreduce_actor->output_data_arrows_) {
  1592. auto output_actor = dynamic_cast<KernelActor *>(FetchActor(output_data_arrow->to_op_id_.Name()));
  1593. if (output_actor != nullptr) {
  1594. (void)to_recv_actor->output_control_arrows_.emplace_back(output_actor->GetAID());
  1595. output_actor->input_controls_num_++;
  1596. }
  1597. }
  1598. // In the scene of allreduce op and computing op parallel multi stream, the input memory of allreduce can be
  1599. // reused only when the recv node runs finished, which is expressed by the reference count increased.
  1600. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(from_allreduce_node); ++i) {
  1601. auto device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(from_allreduce_node, i, false);
  1602. MS_EXCEPTION_IF_NULL(device_tensor);
  1603. UpdateRefCount(device_tensor.get());
  1604. (void)to_recv_actor->external_reference_tensors_.emplace_back(device_tensor.get());
  1605. }
  1606. }
  1607. }
  1608. void GraphScheduler::LinkControlArrowByCommunicationNode(const std::vector<CNodePtr> &communication_nodes,
  1609. const GraphCompilerInfo &graph_compiler_info) {
  1610. const size_t kCommunicationNodesMinNum = 2;
  1611. if (communication_nodes.size() < kCommunicationNodesMinNum) {
  1612. return;
  1613. }
  1614. // Ensure communication node to execute orderly.
  1615. for (size_t i = 1; i < communication_nodes.size(); ++i) {
  1616. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i - 1]->fullname_with_scope()));
  1617. auto to_actor = dynamic_cast<KernelActor *>(FetchActor(communication_nodes[i]->fullname_with_scope()));
  1618. MS_EXCEPTION_IF_NULL(from_actor);
  1619. MS_EXCEPTION_IF_NULL(to_actor);
  1620. (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
  1621. to_actor->input_controls_num_++;
  1622. }
  1623. // Ensure all actors execute orderly to optimize the execution performance in the multi device scenario currently.
  1624. // Using the multi stream to optimize the performance in the future.
  1625. for (auto &graph : graph_compiler_info.graphs_) {
  1626. auto &execution_order = graph->execution_order();
  1627. for (size_t i = 1; i < execution_order.size(); ++i) {
  1628. auto from_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i - 1]->fullname_with_scope()));
  1629. auto to_actor = dynamic_cast<KernelActor *>(FetchActor(execution_order[i]->fullname_with_scope()));
  1630. if ((from_actor != nullptr) && (to_actor != nullptr)) {
  1631. (void)from_actor->output_control_arrows_.emplace_back(to_actor->GetAID());
  1632. to_actor->input_controls_num_++;
  1633. }
  1634. }
  1635. }
  1636. }
  1637. void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const ActorSet *actor_set,
  1638. const ControlNodeParserPtr &parser) {
  1639. MS_EXCEPTION_IF_NULL(actor_set);
  1640. // There is no loop count actor in step mode.
  1641. if (loop_count_actor == nullptr) {
  1642. return;
  1643. }
  1644. // Collect the actors which have no output.
  1645. std::vector<MemoryAwareActor *> no_output_actors;
  1646. for (auto &kernel_actor : actor_set->kernel_actors_) {
  1647. // The no output kernel control side in subgraph needs to be connected to the corresponding output switch actor.
  1648. if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
  1649. parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
  1650. MS_EXCEPTION_IF_NULL(kernel_actor->kernel_);
  1651. MS_LOG(INFO) << kernel_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
  1652. (void)no_output_actors.emplace_back(kernel_actor.get());
  1653. }
  1654. }
  1655. for (auto &data_actor : actor_set->data_source_actors_) {
  1656. if ((data_actor->output_data_arrows_.size() == 0) && (data_actor->output_control_arrows_.size() == 0)) {
  1657. (void)no_output_actors.emplace_back(data_actor.get());
  1658. }
  1659. }
  1660. for (auto &copy_actor : copy_actors_) {
  1661. if ((copy_actor->output_data_arrows_.size() == 0) && (copy_actor->output_control_arrows_.size() == 0)) {
  1662. (void)no_output_actors.emplace_back(copy_actor.get());
  1663. }
  1664. }
  1665. // No output actor --> loop count actor.
  1666. for (auto &no_output_actor : no_output_actors) {
  1667. (void)no_output_actor->output_control_arrows_.emplace_back(loop_count_actor->GetAID());
  1668. loop_count_actor->input_controls_num_++;
  1669. }
  1670. // Loop count actor --> data source actor.
  1671. for (auto &data_source_actor : actor_set->data_source_actors_) {
  1672. MS_EXCEPTION_IF_NULL(data_source_actor);
  1673. (void)loop_count_actor->data_source_aids_.emplace_back(data_source_actor->GetAID());
  1674. }
  1675. // Loop count actor --> no input kernel actor.
  1676. for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  1677. MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
  1678. (void)loop_count_actor->no_input_kernel_aids_.emplace_back(no_input_kernel_actor->GetAID());
  1679. no_input_kernel_actor->input_controls_num_++;
  1680. }
  1681. // Loop count actor --> output actor.
  1682. MS_EXCEPTION_IF_NULL(actor_set->output_actor_);
  1683. loop_count_actor->output_aid_ = actor_set->output_actor_->GetAID();
  1684. }
// Link the result arrows from the actors that produce each graph's outputs (kernel actor,
// data source actor, or the device tensor store) to the output actor, so the final outputs
// can be collected after execution.
void GraphScheduler::LinkOutputResultArrowForOutputActor(OutputActor *to_actor,
                                                         const GraphCompilerInfo &graph_compiler_info) {
  // No output actor result arrows are linked in step execution strategy.
  if (graph_compiler_info.strategy_ == GraphExecutionStrategy::kStep) {
    return;
  }
  MS_EXCEPTION_IF_NULL(to_actor);
  // 'number' is the 1-based index of the current graph, used below to select its device context.
  size_t number = 0;
  for (const auto &graph : graph_compiler_info.graphs_) {
    MS_EXCEPTION_IF_NULL(graph);
    ++number;
    auto outputs = AnfAlgo::GetAllOutputWithIndex(graph->output());
    std::set<std::vector<size_t>> unique_output_positions;
    std::set<KernelWithIndex> unique_outputs;
    // De-duplicate the outputs and drop internal parameters, which are linked elsewhere.
    for (const auto &output : outputs) {
      if (IsInternalParameter(output.first, graph)) {
        MS_LOG(INFO) << "Ignore the internal parameter node:" << output.first->DebugString();
        continue;
      }
      (void)unique_outputs.insert(output);
    }
    for (const auto &output_with_index : unique_outputs) {
      MS_EXCEPTION_IF_NULL(output_with_index.first);
      auto origin_output_with_index = FetchFrontNodeWithIndexByGraphOutput(output_with_index, graph);
      const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
      // Outputs not present in the origin output order need no result arrow.
      if (iter == graph_compiler_info.origin_outputs_order_.end()) {
        continue;
      }
      // Skip duplicate position.
      if (unique_output_positions.count(iter->second) > 0) {
        continue;
      }
      (void)unique_output_positions.insert(iter->second);
      for (auto &output_position : iter->second) {
        to_actor->device_contexts_[output_position] = graph_compiler_info.device_contexts_[number - 1];
        // The device tensor of graph out need be taken over by host tensor, so set the max reference count.
        UpdateRefCount(output_with_index.first, output_with_index.second, true);
        // The graph output is from device tensor store.
        if (IsPersistentDeviceTensor(output_with_index.first)) {
          (void)to_actor->device_tensor_store_keys_.emplace_back(output_position, output_with_index.first);
          continue;
        }
        // The graph output is from kernel actor or data source actor.
        auto kernel_type = KernelTransformType::kUnknown;
        std::string kernel_name = "";
        FetchKernelTransformTypeAndName(output_with_index.first, graph, graph_compiler_info, &kernel_type,
                                        &kernel_name);
        auto from_actor = dynamic_cast<AbstractActor *>(FetchActor(kernel_name));
        // Some outputs may have no producing actor (e.g. already persisted); skip them silently.
        if (from_actor == nullptr) {
          continue;
        }
        auto op_arrow = std::make_shared<DataArrow>(output_with_index.second, to_actor->GetAID(), output_position);
        auto position = from_actor->FetchNodePosition(output_with_index.first);
        // If the from actor has the multi nodes, then use the real output position.
        if (position != 0) {
          op_arrow->from_output_index_ = SizeToInt(position);
        }
        (void)from_actor->output_result_arrows_.emplace_back(op_arrow);
        // For host data source outputs, also raise the ref count on the backing data node.
        if (kernel_type == KernelTransformType::kHostDataSourceActor) {
          auto host_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(from_actor);
          MS_EXCEPTION_IF_NULL(host_queue_ds_actor);
          UpdateRefCount(host_queue_ds_actor->data_nodes_[position], output_with_index.second, true);
        }
      }
    }
  }
}
  1751. void GraphScheduler::LinkOutputResultArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
  1752. const ActorSet *actor_set) {
  1753. const auto &to_actor = actor_set->output_actor_;
  1754. const auto &loop_count_actor = actor_set->loop_count_actor_;
  1755. if (to_actor == nullptr || loop_count_actor == nullptr) {
  1756. return;
  1757. }
  1758. // When there is a call node in the output, the output will be sent to the output actor by the switch actor of
  1759. // the funcgraph called by the call node.
  1760. const auto &outputs = graph_compiler_info.origin_outputs_order_;
  1761. for (const auto &output : outputs) {
  1762. const auto &output_node = output.first.first;
  1763. const auto &output_index = output.first.second;
  1764. const auto output_poses = output.second;
  1765. if (IsCallNode(output_node)) {
  1766. const auto &func_graphs = FetchFuncGraphbyCallNode(output_node);
  1767. for (const auto func_graph : func_graphs) {
  1768. const auto &actor_name = func_graph->get_return()->DebugString();
  1769. auto actor = FetchActor(actor_name);
  1770. MS_EXCEPTION_IF_NULL(actor);
  1771. auto switch_actor = dynamic_cast<SwitchActor *>(actor);
  1772. MS_EXCEPTION_IF_NULL(switch_actor);
  1773. // Set branch index into switch actor.
  1774. size_t branch_index = switch_actor->branch_id_to_index_.size();
  1775. if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
  1776. branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
  1777. } else {
  1778. switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
  1779. }
  1780. // Link output result arrow.
  1781. for (const auto output_pos : output_poses) {
  1782. auto op_arrow = std::make_shared<DataArrow>(output_index, to_actor->GetAID(), output_pos);
  1783. to_actor->device_contexts_[output_pos] = switch_actor->device_context_;
  1784. (void)switch_actor->output_branch_result_arrows_[branch_index].emplace_back(op_arrow);
  1785. }
  1786. }
  1787. }
  1788. }
  1789. const auto &switch_actors = actor_set->switch_actors_;
  1790. for (const auto &from_actor : switch_actors) {
  1791. MS_EXCEPTION_IF_NULL(from_actor);
  1792. auto origin_output_with_index = KernelWithIndex(from_actor->node_, 0);
  1793. const auto &iter = graph_compiler_info.origin_outputs_order_.find(origin_output_with_index);
  1794. if (iter == graph_compiler_info.origin_outputs_order_.end()) {
  1795. continue;
  1796. }
  1797. // If the switch actor is in the output list, the output of switch actor should be sent to the output actor.
  1798. // And need to link a control arrow to the loop count actor.
  1799. for (const auto pos : iter->second) {
  1800. to_actor->device_contexts_[pos] = from_actor->device_context_;
  1801. }
  1802. for (size_t i = 0; i < from_actor->branch_inputs_pos_.size(); ++i) {
  1803. const auto &input_pos = from_actor->branch_inputs_pos_[i];
  1804. if (input_pos.empty()) {
  1805. MS_LOG(EXCEPTION) << "Invalid input num in switch actor:" << from_actor->GetAID();
  1806. }
  1807. for (const auto pos : iter->second) {
  1808. auto op_arrow = std::make_shared<DataArrow>(0, to_actor->GetAID(), pos);
  1809. (void)from_actor->output_branch_result_arrows_[i].emplace_back(op_arrow);
  1810. }
  1811. (void)from_actor->output_branch_control_arrows_[i].emplace_back(loop_count_actor->GetAID());
  1812. }
  1813. loop_count_actor->input_controls_num_++;
  1814. }
  1815. }
// For auto-monad kernel actors whose device tensor store keys are backed on two or more
// devices, insert a copy actor after the kernel that synchronizes the tensor to the other
// device; the kernel's existing control successors are moved behind the copy actor so the
// copy completes before any user runs.
void GraphScheduler::LinkDeviceTensorStoreForAutoMonadActor(const std::vector<KernelActor *> &auto_monad_actors) {
  // A copy is only needed when the same store key has device tensors on at least two devices.
  const size_t kNeedUpdateDeviceTensorStoreNum = 2;
  for (auto &kernel_actor : auto_monad_actors) {
    MS_EXCEPTION_IF_NULL(kernel_actor);
    for (auto &device_tensor_store_key : kernel_actor->device_tensor_store_keys_) {
      auto device_tensors = DeviceTensorStore::GetInstance().Fetch(device_tensor_store_key.second.get());
      if (device_tensors.size() < kNeedUpdateDeviceTensorStoreNum) {
        continue;
      }
      // Create the copy actor.
      std::string name = "copy_from:" + kernel_actor->GetAID().Name() +
                         "_device_tensor_store:" + device_tensor_store_key.second->fullname_with_scope();
      // Skip if a copy actor with this name already exists (the name encodes actor + key).
      if (FetchActor(name) != nullptr) {
        continue;
      }
      auto copy_actor = std::make_shared<CopyActor>(name, memory_manager_aid_);
      MS_EXCEPTION_IF_NULL(copy_actor);
      (void)copy_actors_.emplace_back(copy_actor);
      InsertActor(copy_actor.get());
      // Set the member of the copy actor.
      (void)copy_actor->device_tensor_store_keys_.emplace_back(0, device_tensor_store_key.second);
      auto input_device_context = kernel_actor->device_contexts_[0];
      (void)copy_actor->device_contexts_.emplace_back(input_device_context);
      // Pick whichever stored device tensor lives on a different device than the kernel's
      // input context; that is the copy destination.
      auto another_device_tensor = (device_tensors[0]->DeviceType() == input_device_context->GetDeviceAddressType())
                                     ? device_tensors[1]
                                     : device_tensors[0];
      MS_EXCEPTION_IF_NULL(another_device_tensor);
      auto another_device_type = another_device_tensor->DeviceType();
      const auto &another_device_context = device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext(
        {device::kDeviceTypeToName.at(another_device_type), input_device_context->device_context_key().device_id_});
      MS_EXCEPTION_IF_NULL(another_device_context);
      (void)copy_actor->device_contexts_.emplace_back(another_device_context);
      MS_LOG(INFO) << "The kernel actor: " << kernel_actor->GetAID().Name()
                   << "has control arrows number:" << kernel_actor->output_control_arrows_.size();
      // Link from copy actor to kernel actor users.
      for (auto &output_contorl : kernel_actor->output_control_arrows_) {
        (void)copy_actor->output_control_arrows_.emplace_back(output_contorl);
      }
      // Move the control arrows from kernel actor to kernel actor users.
      kernel_actor->output_control_arrows_.clear();
      // Link from kernel actor to copy actor.
      (void)kernel_actor->output_control_arrows_.emplace_back(copy_actor->GetAID());
      copy_actor->input_controls_num_++;
    }
  }
}
  1862. void GraphScheduler::PrepareInputNodeForSwitchActor(const std::vector<AnfNodePtr> &control_nodes) {
  1863. for (const auto &node : control_nodes) {
  1864. CNodePtr cnode = node->cast<CNodePtr>();
  1865. auto inputs = cnode->inputs();
  1866. // Before link data arrow, parameters of the call node in switch-call need to be add to the switch actor.
  1867. if (inputs[0]->isa<CNode>()) {
  1868. auto actor = FetchActor(inputs[0]->DebugString());
  1869. MS_EXCEPTION_IF_NULL(actor);
  1870. auto switch_actor = dynamic_cast<SwitchActor *>(actor);
  1871. MS_EXCEPTION_IF_NULL(switch_actor);
  1872. for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
  1873. if (HasAbstractMonad(inputs[i])) {
  1874. continue;
  1875. }
  1876. switch_actor->AddCommonInput(inputs[i]);
  1877. }
  1878. }
  1879. }
  1880. }
// Link all arrows driven by control nodes: data arrows for switch/switch_layer nodes and
// call-node inputs, arrows for subgraph-output switch actors and call-input kernel graphs,
// then the branch/control/result arrows of gather and switch actors.
void GraphScheduler::LinkArrowByControlNode(const GraphCompilerInfo &graph_compiler_info, ActorSet *const actor_set) {
  PrepareInputNodeForSwitchActor(graph_compiler_info.control_nodes_);
  for (const auto &node : graph_compiler_info.control_nodes_) {
    CNodePtr cnode = node->cast<CNodePtr>();
    const auto &from_func_graph = node->func_graph();
    auto inputs = cnode->inputs();
    // Link data arrow for switch node.
    if (AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitch) ||
        AnfAlgo::CheckPrimitiveType(node, prim::kPrimSwitchLayer)) {
      auto actor = actor_name_to_actor_[node->DebugString()];
      MS_EXCEPTION_IF_NULL(actor);
      auto switch_actor = dynamic_cast<SwitchActor *>(actor);
      MS_EXCEPTION_IF_NULL(switch_actor);
      LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
    } else if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
      // Link the data arrow for the input of the call node.
      const auto &actor_name = node->DebugString();
      auto actor = FetchActor(actor_name);
      MS_EXCEPTION_IF_NULL(actor);
      auto gather_actor = dynamic_cast<GatherActor *>(actor);
      MS_EXCEPTION_IF_NULL(gather_actor);
      const auto &func_graph = GetValueNode<FuncGraphPtr>(inputs[0]);
      MS_EXCEPTION_IF_NULL(func_graph);
      const auto &to_actor = FetchActor(func_graph->ToString());
      MS_EXCEPTION_IF_NULL(to_actor);
      // persist_input_num counts inputs that do not occupy a data slot (non-tensor values,
      // ref parameters, UpdateState and monad inputs); data slot indices are shifted by it.
      size_t persist_input_num = 0;
      for (size_t i = kCallInputStartPos; i < inputs.size(); ++i) {
        MS_EXCEPTION_IF_NULL(actor);
        if (inputs[i]->isa<ValueNode>()) {
          const auto &node_value = inputs[i]->cast<ValueNodePtr>()->value();
          // Non-tensor values are persisted rather than passed as runtime data.
          if (!node_value->isa<tensor::Tensor>()) {
            persist_input_num++;
            continue;
          }
          // Tensor value nodes come from the device tensor store.
          (void)gather_actor->device_tensor_store_keys_.emplace_back(i - kCallInputStartPos - persist_input_num,
                                                                     inputs[i].get());
          gather_actor->device_contexts_[i - kCallInputStartPos - persist_input_num] =
            graph_compiler_info.control_node_parser_->GetFrontValueNodeDeviceContext(inputs[i]);
        } else if ((inputs[i]->isa<Parameter>() && HasAbstractRef(inputs[i]->cast<ParameterPtr>())) ||
                   AnfAlgo::CheckPrimitiveType(inputs[i], prim::kPrimUpdateState) || HasAbstractMonad(inputs[i])) {
          persist_input_num++;
          continue;
        } else {
          // Ordinary data input: link from its producing actor to the gather actor.
          const auto &input_with_index = AnfAlgo::VisitKernelWithReturnType(inputs[i], 0);
          LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, actor,
                                     i - kCallInputStartPos - persist_input_num);
        }
        // Forward the gathered input onward to the called funcgraph's actor at the same slot.
        auto op_arrow = std::make_shared<DataArrow>(i - kCallInputStartPos - persist_input_num, to_actor->GetAID(),
                                                    i - kCallInputStartPos - persist_input_num);
        (void)gather_actor->output_data_arrows_.emplace_back(op_arrow);
      }
    }
  }
  // Link arrow for switch actor of subgraph output.
  for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
    const auto &func_graph = func_graph_with_call_num.first;
    MS_EXCEPTION_IF_NULL(func_graph);
    auto actor = FetchActor(func_graph->get_return()->DebugString());
    MS_EXCEPTION_IF_NULL(actor);
    auto switch_actor = dynamic_cast<SwitchActor *>(actor);
    MS_EXCEPTION_IF_NULL(switch_actor);
    LinkDataArrowForSwitchActor(graph_compiler_info, switch_actor);
  }
  // Link arrow for gather actor for call input kernel graph.
  for (const auto &call_input_kernel_graph : graph_compiler_info.control_node_parser_->call_input_kernel_graphs_) {
    const auto &kernel_graph = call_input_kernel_graph.first;
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto actor = FetchActor(kernel_graph->ToString());
    MS_EXCEPTION_IF_NULL(actor);
    auto gather_actor = dynamic_cast<GatherActor *>(actor);
    MS_EXCEPTION_IF_NULL(gather_actor);
    for (size_t i = 0; i < gather_actor->data_nodes_.size(); ++i) {
      const auto &input_with_index = gather_actor->data_nodes_[i];
      const auto &from_func_graph = kernel_graph->GetFuncGraph();
      LinkDataArrowByControlNode(graph_compiler_info, input_with_index, from_func_graph, gather_actor, i);
    }
  }
  LinkBranchArrowForSwitchActor(graph_compiler_info);
  LinkBranchArrowForGatherActor(graph_compiler_info);
  LinkControlArrowForGatherActor(&(actor_set->kernel_actors_), graph_compiler_info.graphs_,
                                 graph_compiler_info.control_node_parser_);
  LinkControlArrowForSwitchActor(&(actor_set->switch_actors_), actor_set->loop_count_actor_.get(),
                                 graph_compiler_info.origin_outputs_order_);
  LinkOutputResultArrowForSwitchActor(graph_compiler_info, actor_set);
}
  1966. void GraphScheduler::LinkDataArrowForGatherActor(GatherActor *const from_actor, KernelActor *const to_actor,
  1967. const KernelWithIndex &front_node_with_index,
  1968. const KernelWithIndex &to_node_with_index) {
  1969. MS_EXCEPTION_IF_NULL(from_actor);
  1970. MS_EXCEPTION_IF_NULL(to_actor);
  1971. MS_EXCEPTION_IF_NULL(front_node_with_index.first);
  1972. auto position = from_actor->FetchDataNodePosition(front_node_with_index);
  1973. auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_node_with_index.second);
  1974. (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
  1975. to_actor->input_datas_num_++;
  1976. }
  1977. void GraphScheduler::LinkDataArrowByCallInput(const KernelWithIndex &call_node_with_index,
  1978. const ControlNodeParserPtr &parser, const FuncGraphPtr &from_func_graph,
  1979. OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
  1980. // Fetch all the funcgraph that call node would call.
  1981. const auto cnode = call_node_with_index.first->cast<CNodePtr>();
  1982. std::vector<FuncGraphPtr> func_graphs = FetchFuncGraphbyCallNode(cnode);
  1983. // Collect the output of each funcgraph.
  1984. for (const auto &func_graph : func_graphs) {
  1985. const auto actor_name = func_graph->get_return()->DebugString();
  1986. auto actor = FetchActor(actor_name);
  1987. MS_EXCEPTION_IF_NULL(actor);
  1988. auto switch_actor = dynamic_cast<SwitchActor *>(actor);
  1989. MS_EXCEPTION_IF_NULL(switch_actor);
  1990. const size_t branch_index = switch_actor->branch_id_to_index_.size();
  1991. const auto &func_graph_to_branch_id = parser->func_graph_to_branch_id_;
  1992. const auto &iter = func_graph_to_branch_id.find(from_func_graph);
  1993. int branch_id = kMainBranchID;
  1994. if (iter != func_graph_to_branch_id.end()) {
  1995. branch_id = iter->second;
  1996. }
  1997. if (switch_actor->branch_id_to_index_.find(branch_id) != switch_actor->branch_id_to_index_.end()) {
  1998. LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index,
  1999. switch_actor->branch_id_to_index_[branch_id]);
  2000. continue;
  2001. }
  2002. LinkDataArrowForSwitchActor(switch_actor, call_node_with_index.second, to_actor, to_index, branch_index);
  2003. switch_actor->branch_id_to_index_[branch_id] = branch_index;
  2004. }
  2005. }
  2006. void GraphScheduler::LinkDataArrowForSwitchActor(SwitchActor *from_actor, const size_t from_index,
  2007. OpActor<DeviceTensor> *to_actor, const size_t to_index,
  2008. const size_t branch_index) {
  2009. MS_EXCEPTION_IF_NULL(from_actor);
  2010. MS_EXCEPTION_IF_NULL(to_actor);
  2011. size_t start_branch = 0;
  2012. size_t max_branch = from_actor->output_branch_arrows_.size();
  2013. if (branch_index != SIZE_MAX) {
  2014. start_branch = branch_index;
  2015. max_branch = branch_index + 1;
  2016. }
  2017. for (size_t i = start_branch; i < max_branch; ++i) {
  2018. if (from_actor->branch_inputs_pos_[i].size() <= from_index) {
  2019. MS_LOG(EXCEPTION) << "No input for switch actor:" << from_actor->GetAID() << " branch:" << i
  2020. << " from index:" << from_index << " output size:" << from_actor->branch_inputs_pos_[i].size()
  2021. << " to actor:" << to_actor->GetAID() << " to index:" << to_index;
  2022. }
  2023. auto op_arrow =
  2024. std::make_shared<DataArrow>(from_actor->branch_inputs_pos_[i][from_index], to_actor->GetAID(), to_index);
  2025. (void)from_actor->output_branch_arrows_[i].emplace_back(op_arrow);
  2026. }
  2027. }
// Link a data arrow from whatever produces 'input_with_index' — a call node, a gather actor
// parameter, a switch actor, a kernel actor output, or a host data source parameter — to
// 'to_actor' at slot 'to_index'. Raises an exception when no producer can be found.
void GraphScheduler::LinkDataArrowByControlNode(const GraphCompilerInfo &graph_compiler_info,
                                                const KernelWithIndex &input_with_index,
                                                const FuncGraphPtr &from_func_graph,
                                                OpActor<DeviceTensor> *const to_actor, const size_t to_index) {
  const auto &parameters = graph_compiler_info.origin_parameters_order_;
  const auto &front_to_backend_parameter = graph_compiler_info.control_node_parser_->front_to_backend_parameters_;
  const auto &input_node = input_with_index.first;
  if (IsCallNode(input_node)) {
    // The actor input is a call node.
    LinkDataArrowByCallInput(input_with_index, graph_compiler_info.control_node_parser_, from_func_graph, to_actor,
                             to_index);
  } else if (IsGatherActor(input_node, actor_name_to_actor_)) {
    // The actor input is a parameter in gather actor.
    auto from_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[input_node->func_graph()->ToString()]);
    auto position = from_actor->FetchDataNodePosition({input_node, 0});
    auto op_arrow = std::make_shared<DataArrow>(position, to_actor->GetAID(), to_index);
    (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
  } else if (IsSwitchActor(input_node)) {
    // The actor input comes out of a switch actor; link all of its branches.
    const auto &actor_name = input_node->DebugString();
    auto actor = FetchActor(actor_name);
    MS_EXCEPTION_IF_NULL(actor);
    LinkDataArrowForSwitchActor(dynamic_cast<SwitchActor *>(actor), 0, to_actor, to_index);
  } else if (IsKernelActor(input_node, graph_compiler_info.strategy_)) {
    // The actor input is a cnode.
    if (front_node_to_actor_.find(input_node) == front_node_to_actor_.end()) {
      // No direct front-node actor; fall back to the backend kernel mapped by the parser.
      const auto &kernel_with_index = AnfAlgo::VisitKernelWithReturnType(input_node, 0);
      const auto &backend_node =
        graph_compiler_info.control_node_parser_->GetBackendKernelByFrontKernel(kernel_with_index);
      if (backend_node.first == nullptr) {
        MS_LOG(EXCEPTION) << "Cannot find actor:" << to_actor->GetAID()
                          << " input_node:" << AnfAlgo::GetNodeDebugString(input_node) << " addr:" << input_node;
      }
      const auto &actor_name = backend_node.first->fullname_with_scope();
      const auto &actor = FetchActor(actor_name);
      MS_EXCEPTION_IF_NULL(actor);
      auto from_actor = dynamic_cast<KernelActor *>(actor);
      MS_EXCEPTION_IF_NULL(from_actor);
      auto op_arrow = std::make_shared<DataArrow>(backend_node.second, to_actor->GetAID(), to_index);
      (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
      auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, backend_node.second, false);
      // Keep the device memory alive across the arrow (max reference count).
      UpdateRefCount(device_tensor.get(), true);
      return;
    }
    auto op_arrow = std::make_shared<DataArrow>(input_with_index.second, to_actor->GetAID(), to_index);
    auto from_actor = front_node_to_actor_[input_node];
    (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
    auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->kernel_, input_with_index.second, false);
    UpdateRefCount(device_tensor.get(), true);
  } else if (find(parameters.begin(), parameters.end(), input_node) != parameters.end()) {
    // The actor input is a parameter in host data source actor.
    std::string actor_name = graph_compiler_info.name_ + "_HostDSActor";
    auto actor = FetchActor(actor_name);
    MS_EXCEPTION_IF_NULL(actor);
    auto from_actor = dynamic_cast<HostQueueDataSourceActor *>(actor);
    MS_EXCEPTION_IF_NULL(from_actor);
    auto backend_iter = front_to_backend_parameter.find(input_node);
    if (backend_iter == front_to_backend_parameter.end()) {
      MS_LOG(EXCEPTION) << "Cannot find backend node for front node:" << AnfAlgo::GetNodeDebugString(input_node);
    }
    const auto &backend_node = backend_iter->second.first;
    auto iter = from_actor->data_node_position_map_.find(input_node);
    if (iter == from_actor->data_node_position_map_.end()) {
      MS_LOG(EXCEPTION) << "Cannot find data node in data source actor, backend node:"
                        << AnfAlgo::GetNodeDebugString(backend_node)
                        << " front node:" << AnfAlgo::GetNodeDebugString(input_node);
    }
    auto op_arrow = std::make_shared<DataArrow>(iter->second, to_actor->GetAID(), to_index);
    (void)from_actor->output_data_arrows_.emplace_back(op_arrow);
    auto device_tensor = AnfAlgo::GetMutableOutputAddr(from_actor->data_nodes_[iter->second], 0, false);
    UpdateRefCount(device_tensor.get(), true);
  } else {
    MS_LOG(EXCEPTION) << "Cannot find actor of switch input_node:" << AnfAlgo::GetNodeDebugString(input_node)
                      << " to actor:" << to_actor->GetAID();
  }
}
  2103. void GraphScheduler::LinkDataArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info,
  2104. SwitchActor *const actor) {
  2105. // Link switch input.
  2106. const auto &inputs = actor->input_nodes_;
  2107. for (size_t i = 0; i < inputs.size(); ++i) {
  2108. auto input = inputs[i];
  2109. if (input.first->isa<ValueNode>() || (input.first->isa<Parameter>() && HasAbstractRef(input.first))) {
  2110. continue;
  2111. }
  2112. const FuncGraphPtr from_func_graph = actor->node_->func_graph();
  2113. LinkDataArrowByControlNode(graph_compiler_info, input, from_func_graph, actor, i);
  2114. }
  2115. // Link switch output.
  2116. for (size_t i = 0; i < actor->branch_func_graph_.size(); ++i) {
  2117. auto func_graph = actor->branch_func_graph_[i];
  2118. if (func_graph == nullptr) {
  2119. continue;
  2120. }
  2121. auto gather_name = func_graph->ToString();
  2122. if (actor_name_to_actor_.find(gather_name) == actor_name_to_actor_.end()) {
  2123. MS_LOG(EXCEPTION) << "Cannot find gather actor for funcgraph:" << gather_name
  2124. << ",switch input size:" << actor->input_nodes_.size();
  2125. }
  2126. auto to_actor = dynamic_cast<GatherActor *>(actor_name_to_actor_[gather_name]);
  2127. for (size_t j = 0; j < actor->branch_inputs_pos_[i].size(); ++j) {
  2128. auto pos = actor->branch_inputs_pos_[i][j];
  2129. auto to_actor_index = j;
  2130. auto op_arrow = std::make_shared<DataArrow>(pos, to_actor->GetAID(), to_actor_index);
  2131. (void)actor->output_branch_arrows_[i].emplace_back(op_arrow);
  2132. }
  2133. }
  2134. }
  2135. void GraphScheduler::LinkControlArrowForGatherActor(std::vector<KernelActorPtr> *const kernel_actors,
  2136. const std::vector<KernelGraphPtr> &graphs,
  2137. const ControlNodeParserPtr &parser) {
  2138. // Link control arrow to kernel actor.
  2139. for (size_t i = 0; i < graphs.size(); ++i) {
  2140. const auto &kernel_graph = graphs[i];
  2141. MS_EXCEPTION_IF_NULL(kernel_graph);
  2142. const auto &func_graph = kernel_graph->GetFuncGraph();
  2143. if (func_graph == nullptr) {
  2144. continue;
  2145. }
  2146. const auto &actor = FetchActor(func_graph->ToString());
  2147. if (actor == nullptr) {
  2148. continue;
  2149. }
  2150. const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
  2151. MS_EXCEPTION_IF_NULL(gather_actor);
  2152. // When gather actor is not empty, it means the control arrow of no input kernel actor needs to be sent by gather.
  2153. for (const auto &kernel : kernel_graph->execution_order()) {
  2154. if (IsKernelActor(kernel) && (!IsSkippedKernelActor(kernel))) {
  2155. const auto &kernel_actor = dynamic_cast<KernelActor *>(FetchActor(kernel->fullname_with_scope()));
  2156. MS_EXCEPTION_IF_NULL(kernel_actor);
  2157. if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
  2158. (void)gather_actor->output_control_arrows_.emplace_back(kernel_actor->GetAID());
  2159. kernel_actor->input_controls_num_ = 1;
  2160. }
  2161. }
  2162. }
  2163. }
  2164. for (auto &kernel_actor : *kernel_actors) {
  2165. MS_EXCEPTION_IF_NULL(kernel_actor);
  2166. if ((kernel_actor->output_data_arrows_.size() == 0) && (kernel_actor->output_control_arrows_.size() == 0) &&
  2167. !parser->IsKernelInRootFuncGraph(kernel_actor->kernel_)) {
  2168. // Check whether the kernel actor belongs to the root graph.
  2169. // In general, all no output nodes belong to the root funcgraph, and the corresponding switch actor for output
  2170. // should be empty. In control flow, the control arrow of the no output node in the sub funcgraph should be
  2171. // sent to the output switch actor.
  2172. const auto &graph = kernel_actor->kernel_->func_graph();
  2173. OpActor<DeviceTensor> *actor = nullptr;
  2174. if (graph != nullptr) {
  2175. const auto &kernel_graph = dynamic_cast<KernelGraph *>(graph.get());
  2176. const auto func_graph = kernel_graph->GetFuncGraph();
  2177. if (func_graph != nullptr) {
  2178. actor = FetchActor(func_graph->get_return()->DebugString());
  2179. if (actor != nullptr) {
  2180. auto switch_actor = dynamic_cast<SwitchActor *>(actor);
  2181. MS_EXCEPTION_IF_NULL(switch_actor);
  2182. (void)kernel_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
  2183. switch_actor->input_controls_num_++;
  2184. }
  2185. }
  2186. }
  2187. }
  2188. }
  2189. // Link input auto monad control arrow from kernel actor to gather actor.
  2190. const auto &monad_nodes = parser->kernel_to_call_nodes_;
  2191. for (const auto node_pair : monad_nodes) {
  2192. const auto &kernel_actor_name = node_pair.first->fullname_with_scope();
  2193. const auto &gather_actor_name = node_pair.second->DebugString();
  2194. auto kernel_op_actor = FetchActor(kernel_actor_name);
  2195. auto gather_op_actor = FetchActor(gather_actor_name);
  2196. if (kernel_op_actor == nullptr || gather_op_actor == nullptr) {
  2197. continue;
  2198. }
  2199. auto kernel_actor = dynamic_cast<KernelActor *>(kernel_op_actor);
  2200. auto gather_actor = dynamic_cast<GatherActor *>(gather_op_actor);
  2201. (void)kernel_actor->output_control_arrows_.emplace_back(gather_actor->GetAID());
  2202. gather_actor->input_controls_num_++;
  2203. }
  2204. }
  2205. void GraphScheduler::LinkControlArrowForSwitchActor(std::vector<SwitchActorPtr> *const switch_actors,
  2206. LoopCountActor *const to_actor,
  2207. const KernelMapPosition &origin_outputs_order) {
  2208. if (to_actor == nullptr || (*switch_actors).empty()) {
  2209. return;
  2210. }
  2211. // If there is no output from the switch actor branch, it means that the subgraph has no input,
  2212. // and need to connect a control arrow to the corresponding gather actor.
  2213. for (auto &switch_actor : (*switch_actors)) {
  2214. if (AnfAlgo::CheckPrimitiveType(switch_actor->node_, prim::kPrimReturn)) {
  2215. const auto &func_graph = switch_actor->node_->func_graph();
  2216. if (func_graph->output()->isa<ValueNode>()) {
  2217. const auto &actor_name = func_graph->ToString();
  2218. auto actor = FetchActor(actor_name);
  2219. MS_EXCEPTION_IF_NULL(actor);
  2220. auto gather_actor = dynamic_cast<GatherActor *>(actor);
  2221. MS_EXCEPTION_IF_NULL(gather_actor);
  2222. (void)gather_actor->output_control_arrows_.emplace_back(switch_actor->GetAID());
  2223. switch_actor->input_controls_num_++;
  2224. }
  2225. }
  2226. for (size_t i = 0; i < switch_actor->output_branch_arrows_.size(); ++i) {
  2227. const auto &arrows = switch_actor->output_branch_arrows_[i];
  2228. if (arrows.empty() && switch_actor->branch_func_graph_[i] != nullptr) {
  2229. const auto &actor_name = switch_actor->branch_func_graph_[i]->ToString();
  2230. const auto &actor = FetchActor(actor_name);
  2231. if (actor != nullptr) {
  2232. const auto &gather_actor = dynamic_cast<GatherActor *>(actor);
  2233. MS_EXCEPTION_IF_NULL(gather_actor);
  2234. (void)switch_actor->output_branch_control_arrows_[i].emplace_back(gather_actor->GetAID());
  2235. gather_actor->input_controls_num_++;
  2236. }
  2237. }
  2238. }
  2239. }
  2240. // Collect all the call node in outputs.
  2241. std::set<AnfNodePtr> call_nodes;
  2242. for (const auto &output : origin_outputs_order) {
  2243. if (IsCallNode(output.first.first)) {
  2244. (void)call_nodes.insert(output.first.first);
  2245. }
  2246. }
  2247. to_actor->input_controls_num_ += call_nodes.size();
  2248. // Link the output switch actor of the subgraph to the output actor.
  2249. for (const auto &call_node : call_nodes) {
  2250. const auto &func_graphs = FetchFuncGraphbyCallNode(call_node);
  2251. for (const auto func_graph : func_graphs) {
  2252. MS_EXCEPTION_IF_NULL(func_graph);
  2253. const auto &actor_name = func_graph->get_return()->DebugString();
  2254. auto actor = FetchActor(actor_name);
  2255. MS_EXCEPTION_IF_NULL(actor);
  2256. auto switch_actor = dynamic_cast<SwitchActor *>(actor);
  2257. MS_EXCEPTION_IF_NULL(switch_actor);
  2258. size_t branch_index = switch_actor->branch_id_to_index_.size();
  2259. if (switch_actor->branch_id_to_index_.find(kMainBranchID) != switch_actor->branch_id_to_index_.end()) {
  2260. branch_index = switch_actor->branch_id_to_index_[kMainBranchID];
  2261. } else {
  2262. switch_actor->branch_id_to_index_[kMainBranchID] = branch_index;
  2263. }
  2264. (void)switch_actor->output_branch_control_arrows_[branch_index].emplace_back(to_actor->GetAID());
  2265. }
  2266. }
  2267. }
  2268. void GraphScheduler::LinkBranchArrowForSwitchActor(const GraphCompilerInfo &graph_compiler_info) {
  2269. for (const auto &control_node : graph_compiler_info.control_nodes_) {
  2270. if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitch) ||
  2271. AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimSwitchLayer)) {
  2272. const auto &actor_name = control_node->DebugString();
  2273. auto actor = FetchActor(actor_name);
  2274. MS_EXCEPTION_IF_NULL(actor);
  2275. auto switch_actor = dynamic_cast<SwitchActor *>(actor);
  2276. MS_EXCEPTION_IF_NULL(switch_actor);
  2277. for (size_t i = 0; i < switch_actor->branch_func_graph_.size(); ++i) {
  2278. const auto &func_graph = switch_actor->branch_func_graph_[i];
  2279. if (func_graph == nullptr) {
  2280. continue;
  2281. }
  2282. const auto &gather_actor = FetchActor(func_graph->ToString());
  2283. MS_EXCEPTION_IF_NULL(gather_actor);
  2284. (void)switch_actor->output_branch_branch_arrows_[i].emplace_back(gather_actor->GetAID());
  2285. }
  2286. }
  2287. }
  2288. }
  2289. void GraphScheduler::LinkBranchArrowForGatherActor(const GraphCompilerInfo &graph_compiler_info) {
  2290. if (graph_compiler_info.control_nodes_.empty()) {
  2291. return;
  2292. }
  2293. // Link branch arrow from gather actor to gather actor.
  2294. for (const auto &control_node : graph_compiler_info.control_nodes_) {
  2295. const auto &cnode = control_node->cast<CNodePtr>();
  2296. const auto &inputs = cnode->inputs();
  2297. if (inputs[0]->isa<ValueNode>() && IsValueNode<FuncGraph>(inputs[0])) {
  2298. const auto &actor_name = control_node->DebugString();
  2299. auto actor = FetchActor(actor_name);
  2300. MS_EXCEPTION_IF_NULL(actor);
  2301. auto gather_actor = dynamic_cast<GatherActor *>(actor);
  2302. MS_EXCEPTION_IF_NULL(gather_actor);
  2303. (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->gather_aid_);
  2304. }
  2305. }
  2306. // Link branch arrow from gather actor to switch actor.
  2307. for (const auto &func_graph_with_call_num : graph_compiler_info.control_node_parser_->func_graph_to_call_num_) {
  2308. const auto &actor_name = func_graph_with_call_num.first->ToString();
  2309. auto actor = FetchActor(actor_name);
  2310. MS_EXCEPTION_IF_NULL(actor);
  2311. auto gather_actor = dynamic_cast<GatherActor *>(actor);
  2312. MS_EXCEPTION_IF_NULL(gather_actor);
  2313. (void)gather_actor->output_branch_arrows_.emplace_back(gather_actor->switch_aid_);
  2314. }
  2315. }
// Validate the linked actor set before running: every producing actor must have at least one
// consumer, and each consuming actor's linked input count must match what its kernel expects.
// Logs an error and returns false on the first inconsistency; returns true when all checks pass.
// NOTE: the check order (data source -> kernel -> copy -> loop count) mirrors the order the
// actors are built; do not reorder without reason.
bool GraphScheduler::CheckActorValid(const ActorSet *actor_set, GraphExecutionStrategy strategy) const {
  MS_EXCEPTION_IF_NULL(actor_set);
  // Check the data source actors.
  for (const auto &data_source_actor : actor_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(data_source_actor);
    // A data source with no data/result/control consumers would produce into the void.
    if (data_source_actor->output_data_arrows_.size() + data_source_actor->output_result_arrows_.size() +
          data_source_actor->output_control_arrows_.size() ==
        0) {
      MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
      return false;
    }
  }

  // In step strategy only the data source checks apply; the remaining arrow-count checks are
  // skipped. NOTE(review): presumably because kernels are dispatched individually in step mode
  // and are not fully linked — confirm against the step-mode transform.
  if (strategy == GraphExecutionStrategy::kStep) {
    return true;
  }

  // Check the kernel actors.
  for (const auto &kernel_actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(kernel_actor);
    if (kernel_actor->output_data_arrows_.size() + kernel_actor->output_control_arrows_.size() == 0) {
      MS_LOG(ERROR) << kernel_actor->GetAID().Name() << " has no user.";
      return false;
    }

    // Every kernel input must be fed either by a data arrow or by the device tensor store.
    auto input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
    auto input_data_num = kernel_actor->input_datas_num_;
    auto device_tensor_store_num = kernel_actor->device_tensor_store_keys_.size();
    if (input_data_num + device_tensor_store_num != input_num) {
      MS_LOG(ERROR) << "The input building of " << AnfAlgo::GetNodeDebugString(kernel_actor->kernel_)
                    << " is wrong, input data num: " << input_data_num
                    << ", device tensor store num: " << device_tensor_store_num << ", total input num: " << input_num;
      return false;
    }
  }

  // Check the copy actors.
  for (const auto &copy_actor : actor_set->copy_actors_) {
    MS_EXCEPTION_IF_NULL(copy_actor);
    if (copy_actor->output_data_arrows_.size() + copy_actor->output_control_arrows_.size() == 0) {
      MS_LOG(ERROR) << copy_actor->GetAID().Name() << " has no user.";
      return false;
    }

    // A copy actor copies exactly one tensor, fed by either a data arrow or the store.
    const size_t kCopyActorInputDataNum = 1;
    auto input_data_num = copy_actor->input_datas_num_;
    size_t device_tensor_store_num = copy_actor->device_tensor_store_keys_.size();
    if (input_data_num + device_tensor_store_num != kCopyActorInputDataNum) {
      MS_LOG(ERROR) << "The input building of " << copy_actor->GetAID().Name()
                    << " is wrong, input data num: " << input_data_num
                    << ", device tensor store num: " << device_tensor_store_num
                    << ", total input num: " << kCopyActorInputDataNum;
      return false;
    }
  }

  // Check the loop count actor: when there is at least one producing actor, something must
  // report completion to the loop count actor via a control arrow.
  const auto &loop_count_actor = actor_set->loop_count_actor_;
  if ((loop_count_actor != nullptr) &&
      (actor_set->data_source_actors_.size() + actor_set->kernel_actors_.size() + actor_set->copy_actors_.size() > 0)) {
    if (loop_count_actor->input_controls_num_ == 0) {
      MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source.";
      return false;
    }
  }

  return true;
}
  2377. void GraphScheduler::PersistDeviceTensor(const GraphCompilerInfo &graph_compiler_info) {
  2378. for (size_t i = 0; i < graph_compiler_info.graphs_.size(); ++i) {
  2379. const auto &graph = graph_compiler_info.graphs_[i];
  2380. const auto &device_context = graph_compiler_info.device_contexts_[i];
  2381. MS_EXCEPTION_IF_NULL(graph);
  2382. MS_EXCEPTION_IF_NULL(device_context);
  2383. for (auto &value_node : graph->graph_value_nodes()) {
  2384. MS_EXCEPTION_IF_NULL(value_node);
  2385. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  2386. MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
  2387. continue;
  2388. }
  2389. auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0, false);
  2390. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  2391. DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
  2392. UpdateRefCount(device_tensor.get(), true);
  2393. }
  2394. for (auto &input_node : graph->input_nodes()) {
  2395. MS_EXCEPTION_IF_NULL(input_node);
  2396. AnfNodePtr sub_front_node = nullptr;
  2397. if (IsInternalParameter(input_node, graph)) {
  2398. auto front_output_with_index = graph->GetFrontNodeByInternalParameter(input_node);
  2399. sub_front_node = front_output_with_index.first;
  2400. } else if (IsPersistentDeviceTensor(input_node) || HasAbstractRef(input_node)) {
  2401. sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
  2402. }
  2403. if (sub_front_node == nullptr) {
  2404. continue;
  2405. }
  2406. // The sub front nodes share the device tensor store with the root front node.
  2407. auto front_node = sub_front_node;
  2408. if (graph_compiler_info.control_node_parser_ != nullptr) {
  2409. front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
  2410. }
  2411. MS_LOG(DEBUG) << "Graph id:" << graph->graph_id() << ", sub front node:" << sub_front_node->DebugString()
  2412. << ", root front node:" << front_node->DebugString();
  2413. auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0, false);
  2414. MS_EXCEPTION_IF_NULL(device_tensor);
  2415. if (IsPersistentDeviceTensor(input_node)) {
  2416. DeviceTensorStore::GetInstance().Insert(front_node.get(), device_tensor);
  2417. UpdateRefCount(device_tensor.get(), true);
  2418. }
  2419. // Share the weight in the host and device, then input_node is internal parameter and front_node is weight.
  2420. if (!IsPersistentDeviceTensor(front_node)) {
  2421. continue;
  2422. }
  2423. // If the device tensor store of this device type is not exist, then create the new device tensor of this type.
  2424. if (DeviceTensorStore::GetInstance().Fetch(front_node.get(), device_context->GetDeviceAddressType()) == nullptr) {
  2425. MS_LOG(INFO) << "Fetch no device tensor store by:" << front_node->fullname_with_scope()
  2426. << ", type:" << device_context->GetDeviceAddressType();
  2427. auto other_type_device_tensor = device_context->CreateDeviceAddress(
  2428. nullptr, device_tensor->GetSize(), device_tensor->format(), device_tensor->type_id());
  2429. DeviceTensorStore::GetInstance().Insert(front_node.get(), other_type_device_tensor);
  2430. UpdateRefCount(other_type_device_tensor.get(), true);
  2431. }
  2432. }
  2433. }
  2434. // In control flow, there may be some value nodes that is not in the kernel graph and needs to be placed
  2435. // in the tensor store separately.
  2436. for (const auto &value_node : graph_compiler_info.control_node_parser_->front_value_nodes_) {
  2437. MS_EXCEPTION_IF_NULL(value_node.first);
  2438. auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node.first, 0, false);
  2439. DeviceTensorStore::GetInstance().Insert(value_node.first.get(), device_tensor);
  2440. UpdateRefCount(device_tensor.get(), true);
  2441. }
  2442. }
  2443. HostTensorQueue *GraphScheduler::FetchHostQueue(const ActorInfo &actor_info) const {
  2444. const auto &iter = actor_to_host_queue_.find(actor_info);
  2445. if (iter != actor_to_host_queue_.end()) {
  2446. return iter->second.get();
  2447. } else {
  2448. return nullptr;
  2449. }
  2450. }
  2451. void GraphScheduler::FetchKernelTransformTypeAndName(const AnfNodePtr &node, const KernelGraphPtr &graph,
  2452. const GraphCompilerInfo &graph_compiler_info,
  2453. KernelTransformType *const kernel_type,
  2454. std::string *const kernel_name) {
  2455. MS_EXCEPTION_IF_NULL(node);
  2456. MS_EXCEPTION_IF_NULL(graph);
  2457. MS_EXCEPTION_IF_NULL(kernel_type);
  2458. MS_EXCEPTION_IF_NULL(kernel_name);
  2459. if (IsDeviceQueueDSActor(node, graph_compiler_info.strategy_)) {
  2460. *kernel_type = KernelTransformType::kDeviceDataSourceActor;
  2461. *kernel_name = graph_compiler_info.name_ + "_DeviceDSActor" + "_" + std::to_string(graph->graph_id());
  2462. } else if (IsHostQueueDSActor(node, graph, graph_compiler_info.origin_parameters_order_,
  2463. graph_compiler_info.strategy_)) {
  2464. *kernel_type = KernelTransformType::kHostDataSourceActor;
  2465. *kernel_name = graph_compiler_info.name_ + "_HostDSActor";
  2466. } else if (IsKernelActor(node, graph_compiler_info.strategy_)) {
  2467. *kernel_type = KernelTransformType::kKernelActor;
  2468. *kernel_name = node->fullname_with_scope();
  2469. } else if (IsInternalParameter(node, graph)) {
  2470. *kernel_type = KernelTransformType::kInternalParameter;
  2471. *kernel_name = "";
  2472. } else if (IsPersistentDeviceTensor(node)) {
  2473. *kernel_type = KernelTransformType::kDeviceTensorStore;
  2474. *kernel_name = "";
  2475. } else {
  2476. // May exist the from kernel that no need link in the pynative mode.
  2477. MS_LOG(DEBUG) << "Invalid from kernel: " << node->fullname_with_scope();
  2478. *kernel_type = KernelTransformType::kUnknown;
  2479. *kernel_name = "";
  2480. }
  2481. }
  2482. void GraphScheduler::InsertActor(OpActor<DeviceTensor> *actor) {
  2483. MS_EXCEPTION_IF_NULL(actor);
  2484. if (actor_name_to_actor_.count(actor->GetAID().Name()) > 0) {
  2485. MS_LOG(EXCEPTION) << "The actor already exists: " << actor->GetAID().Name();
  2486. }
  2487. actor_name_to_actor_[actor->GetAID().Name()] = actor;
  2488. }
  2489. OpActor<DeviceTensor> *GraphScheduler::FetchActor(const std::string &actor_name) const {
  2490. const auto &iter = actor_name_to_actor_.find(actor_name);
  2491. if (iter == actor_name_to_actor_.end()) {
  2492. return nullptr;
  2493. }
  2494. return iter->second;
  2495. }
// Dump the whole actor set to "actor_set_<name>.ir" under the save-graphs path, one section per
// actor category. Only active when the MS_CTX_SAVE_GRAPHS_FLAG context flag is set.
void GraphScheduler::DumpActor(const ActorSet *actor_set, const GraphCompilerInfo &graph_compiler_info) const {
  MS_EXCEPTION_IF_NULL(actor_set);
  const auto &context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  auto save_graphs = context_ptr->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG);
  if (!save_graphs) {
    return;
  }

  std::string filename = GetSaveGraphsPathName("actor_set_" + actor_set->name_ + ".ir");
  std::ofstream ofs(filename);
  if (!ofs.is_open()) {
    MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
    return;
  }

  ofs << "[Device tensor stores]\n";
  DumpDeviceTensorStore(graph_compiler_info, ofs);

  ofs << "\n\n[Data source actors:" << actor_set->data_source_actors_.size() << "]\n";
  for (const auto &data_source_actor : actor_set->data_source_actors_) {
    DumpDSActor(data_source_actor.get(), ofs);
  }

  ofs << "\n\n[Kernel actors:" << actor_set->kernel_actors_.size() << "]\n";
  for (const auto &kernel_actor : actor_set->kernel_actors_) {
    DumpKernelActor(kernel_actor.get(), ofs);
  }

  // No-input kernel actors are also members of kernel_actors_, so they appear twice by design.
  ofs << "\n\n[No input kernel actors:" << actor_set->no_input_kernel_actors_.size() << "]\n";
  for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
    DumpKernelActor(no_input_kernel_actor.get(), ofs);
  }

  ofs << "\n\n[Copy actors:" << actor_set->copy_actors_.size() << "]\n";
  for (const auto &copy_actor : actor_set->copy_actors_) {
    DumpCopyActor(copy_actor.get(), ofs);
  }

  ofs << "\n\n[Gather actors:" << actor_set->gather_actors_.size() << "]\n";
  for (const auto &gather_actor : actor_set->gather_actors_) {
    DumpGatherActor(gather_actor.get(), ofs);
  }

  ofs << "\n\n[Switch actors:" << actor_set->switch_actors_.size() << "]\n";
  for (const auto &switch_actor : actor_set->switch_actors_) {
    DumpSwitchActor(switch_actor.get(), ofs);
  }

  // Loop count and output actors are optional singletons; dump 0 or 1 of each.
  const auto &loop_count_actor = actor_set->loop_count_actor_;
  ofs << "\n\n[Loop count actor:" << (loop_count_actor != nullptr ? 1 : 0) << "]\n";
  if (loop_count_actor != nullptr) {
    DumpLoopCountActor(loop_count_actor.get(), ofs);
  }

  const auto &output_actor = actor_set->output_actor_;
  ofs << "\n\n[Output actor:" << (output_actor != nullptr ? 1 : 0) << "]\n";
  if (output_actor != nullptr) {
    DumpOutputActor(output_actor.get(), ofs);
  }
}
// Dump the members common to all actors: device contexts, device tensor store keys, and the
// input/output data, control and result arrows. Called by every per-category dump function.
void GraphScheduler::DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) const {
  MS_EXCEPTION_IF_NULL(actor);
  // Summary counters first, then the detailed lists below.
  ofs << "\t\tdevice_contexts_num:" << actor->device_contexts_.size()
      << "\tdevice_tensor_store_keys_num:" << actor->device_tensor_store_keys_.size()
      << "\tinput_data_arrow_actors_num:" << actor->input_datas_num_
      << "\tinput_control_arrow_actors_num:" << actor->input_controls_num_ << "\n";
  ofs << "\t\toutput_data_arrows_num:" << actor->output_data_arrows_.size()
      << "\toutput_control_arrows_num:" << actor->output_control_arrows_.size()
      << "\toutput_result_arrows_num:" << actor->output_result_arrows_.size() << "\n";

  if (actor->device_contexts_.size() > 0) {
    ofs << "\t\tdevice_contexts:" << actor->device_contexts_.size() << "\n ";
    for (const auto &device_context : actor->device_contexts_) {
      // A null context is dumped as its pointer value rather than crashing the dump.
      if (device_context == nullptr) {
        ofs << "\t\t\tdevice_context:" << device_context << "\n";
        continue;
      }
      ofs << "\t\t\tdevice_context:" << device_context->device_context_key().ToString() << "\n";
    }
  }

  if (actor->device_tensor_store_keys_.size() > 0) {
    ofs << "\t\tdevice_tensor_store_keys:" << actor->device_tensor_store_keys_.size() << "\n ";
    for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
      MS_EXCEPTION_IF_NULL(device_tensor_store_key.second);
      ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
          << "\tfrom_node_name:" << device_tensor_store_key.second->fullname_with_scope() << "\n";
    }
  }

  if (actor->input_data_arrow_aids_.size() > 0) {
    ofs << "\t\tinput_data_arrow_actors:" << actor->input_data_arrow_aids_.size() << "\n ";
    for (const auto &input_data_arrow_aid : actor->input_data_arrow_aids_) {
      ofs << "\t\t\tfrom_actor_name:" << input_data_arrow_aid.Name() << "\n";
    }
  }

  if (actor->input_control_arrow_aids_.size() > 0) {
    ofs << "\t\tinput_control_arrow_actors:" << actor->input_control_arrow_aids_.size() << "\n ";
    for (const auto &input_control_arrow_aid : actor->input_control_arrow_aids_) {
      ofs << "\t\t\tfrom_actor_name:" << input_control_arrow_aid.Name() << "\n";
    }
  }

  const auto &output_data_arrows = actor->output_data_arrows();
  if (output_data_arrows.size() > 0) {
    ofs << "\t\toutput_data_arrows:" << output_data_arrows.size() << "\n ";
    for (const auto &data_arrow : output_data_arrows) {
      MS_EXCEPTION_IF_NULL(data_arrow);
      ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
          << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
          << "\n";
    }
  }

  const auto &output_control_arrows = actor->output_control_arrows();
  if (output_control_arrows.size() > 0) {
    ofs << "\t\toutput_control_arrows:" << output_control_arrows.size() << "\n ";
    for (const auto &aid : output_control_arrows) {
      ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
    }
  }

  if (actor->output_result_arrows_.size() > 0) {
    ofs << "\t\toutput_result_arrows:" << actor->output_result_arrows_.size() << "\n ";
    for (const auto &result_arrow : actor->output_result_arrows_) {
      MS_EXCEPTION_IF_NULL(result_arrow);
      ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
          << "\tto_actor_name:" << result_arrow->to_op_id_.Name()
          << "\toutput_node_position:" << result_arrow->to_input_index_ << "\n";
    }
  }
}
// Dump a data source actor. The concrete subtype is recognized by its name suffix:
// "_DeviceDSActor" (device queue) or "_HostDSActor" (host queue); other names dump only the
// common abstract-actor members.
void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const {
  MS_EXCEPTION_IF_NULL(actor);
  const auto &actor_name = actor->GetAID().Name();
  ofs << "\tactor_name:" << actor_name << "\n";

  if (actor_name.find("_DeviceDSActor") != string::npos) {
    // Dump the member info of device queue data source actor.
    // NOTE(review): the dynamic_cast result is dereferenced without a null check; the name
    // suffix is assumed to guarantee the subtype.
    const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
    const auto &data_kernel = device_queue_ds_actor->data_kernel_;
    MS_EXCEPTION_IF_NULL(data_kernel);
    ofs << "\t\tdata_kernel_name:" << data_kernel->fullname_with_scope()
        << "\tinput_number:" << AnfAlgo::GetInputTensorNum(data_kernel)
        << "\toutput_number:" << AnfAlgo::GetOutputTensorNum(data_kernel) << "\n";
    for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(data_kernel); ++i) {
      const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_kernel, i, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
          << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
    }
  } else if (actor_name.find("_HostDSActor") != string::npos) {
    // Dump the member info of host queue data source actor.
    const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
    ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";
    for (size_t i = 0; i < host_queue_ds_actor->data_nodes_.size(); ++i) {
      const auto &data_node = host_queue_ds_actor->data_nodes_[i];
      MS_EXCEPTION_IF_NULL(data_node);
      const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_node, 0, false);
      MS_EXCEPTION_IF_NULL(device_tensor);
      ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope()
          << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
          << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n";
    }
  }

  DumpAbstractActor(actor, ofs);
  ofs << "\n";
}
// Dump the loop count actor: loop count, common members, its outgoing control arrows
// (data source actors, no-input kernel actors and the output actor) and its continuous
// memory node table.
void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const {
  MS_EXCEPTION_IF_NULL(actor);
  ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_ << "\n";
  DumpAbstractActor(actor, ofs);

  // The "+ 1" accounts for the arrow to the output actor dumped after the two loops.
  ofs << "\t\toutput_control_arrows:" << (actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size() + 1)
      << "\n ";
  for (const auto &aid : actor->data_source_aids_) {
    ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
  }
  for (const auto &aid : actor->no_input_kernel_aids_) {
    ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
  }
  ofs << "\t\t\tto_actor_name:" << actor->output_aid_.Name() << "\n";

  ofs << "\t\tcontinuous_memory_nodes:" << actor->continuous_memory_nodes_.size() << "\n ";
  for (const auto &iter : actor->continuous_memory_nodes_) {
    ofs << "\t\t\tnode_name:" << iter.first.first->fullname_with_scope()
        << "\tdevice_context:" << iter.first.second->device_context_key().ToString()
        << "\tis_input_need:" << iter.second.first << "\tis_output_need:" << iter.second.second << "\n";
  }
}
  2668. void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const {
  2669. MS_EXCEPTION_IF_NULL(actor);
  2670. ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
  2671. const auto &kernel = actor->kernel_;
  2672. MS_EXCEPTION_IF_NULL(kernel);
  2673. ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinputs_num:" << AnfAlgo::GetInputTensorNum(kernel)
  2674. << "\toutputs_num:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n";
  2675. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
  2676. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
  2677. MS_EXCEPTION_IF_NULL(device_tensor);
  2678. ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
  2679. << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
  2680. }
  2681. DumpAbstractActor(actor, ofs);
  2682. ofs << "\n";
  2683. }
  2684. void GraphScheduler::DumpOutputActor(const OutputActor *actor, std::ofstream &ofs) const {
  2685. MS_EXCEPTION_IF_NULL(actor);
  2686. ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
  2687. << "\toutputs_num:" << actor->outputs_num_ << "\n";
  2688. DumpAbstractActor(actor, ofs);
  2689. }
  2690. void GraphScheduler::DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) const {
  2691. MS_EXCEPTION_IF_NULL(actor);
  2692. ofs << "\tactor_name:" << actor->GetAID().Name() << "\n";
  2693. auto device_tensor = actor->output_;
  2694. if (device_tensor != nullptr) {
  2695. ofs << "\t\toutput_index:" << 0 << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
  2696. << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
  2697. }
  2698. DumpAbstractActor(actor, ofs);
  2699. ofs << "\n";
  2700. }
  2701. void GraphScheduler::DumpDeviceTensorStore(const GraphCompilerInfo &graph_compiler_info, std::ofstream &ofs) const {
  2702. for (const auto &graph : graph_compiler_info.graphs_) {
  2703. MS_EXCEPTION_IF_NULL(graph);
  2704. ofs << "\tgraph id:" << graph->graph_id() << "\n";
  2705. for (auto &value_node : graph->graph_value_nodes()) {
  2706. MS_EXCEPTION_IF_NULL(value_node);
  2707. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  2708. continue;
  2709. }
  2710. const auto &front_node = FetchFrontNodeByBackendNode(value_node, graph);
  2711. MS_EXCEPTION_IF_NULL(front_node);
  2712. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  2713. ofs << "\t\tdevice tensor key:" << front_node->DebugString() << "\tvalue size:" << device_tensors.size() << "\n";
  2714. for (const auto &device_tensor : device_tensors) {
  2715. MS_EXCEPTION_IF_NULL(device_tensor);
  2716. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  2717. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  2718. << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
  2719. }
  2720. }
  2721. for (auto &input_node : graph->input_nodes()) {
  2722. MS_EXCEPTION_IF_NULL(input_node);
  2723. if (!IsPersistentDeviceTensor(input_node)) {
  2724. continue;
  2725. }
  2726. const auto &sub_front_node = FetchFrontNodeByBackendNode(input_node, graph);
  2727. // The sub front nodes share the device tensor store with the root front node.
  2728. auto front_node = sub_front_node;
  2729. if (graph_compiler_info.control_node_parser_ != nullptr) {
  2730. front_node = graph_compiler_info.control_node_parser_->FetchRootGraphFrontNodeBySubFrontNode(sub_front_node);
  2731. }
  2732. const auto device_tensors = DeviceTensorStore::GetInstance().Fetch(front_node.get());
  2733. MS_EXCEPTION_IF_NULL(front_node);
  2734. ofs << "\t\tdevice tensor key:" << front_node->DebugString() << "\tvalue size:" << device_tensors.size() << "\n";
  2735. for (const auto &device_tensor : device_tensors) {
  2736. MS_EXCEPTION_IF_NULL(device_tensor);
  2737. ofs << "\t\t\tdevice tensor value:" << device_tensor << "\tptr:" << device_tensor->GetPtr()
  2738. << "\tsize:" << device_tensor->GetSize() << "\toriginal_ref_count:" << device_tensor->original_ref_count()
  2739. << "\tdevice_type:" << device_tensor->DeviceType() << "\n ";
  2740. }
  2741. }
  2742. ofs << "\n";
  2743. }
  2744. }
  2745. void GraphScheduler::DumpGatherActor(const GatherActor *actor, std::ofstream &ofs) const {
  2746. MS_EXCEPTION_IF_NULL(actor);
  2747. ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
  2748. ofs << "\t\tactor input num:" << actor->data_nodes_.size() << "\n";
  2749. for (const auto &node : actor->data_nodes_) {
  2750. ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << "\tindex:" << node.second << '\n';
  2751. }
  2752. ofs << "\t\tactor front to backend node:\n";
  2753. for (const auto &front_to_backend_parameter : actor->front_to_backend_parameter_) {
  2754. ofs << "\t\t\tfront node:" << AnfAlgo::GetNodeDebugString(front_to_backend_parameter.first) << '\n';
  2755. for (const auto node_with_index : front_to_backend_parameter.second) {
  2756. ofs << "\t\t\t\tbackend node:" << AnfAlgo::GetNodeDebugString(node_with_index.first)
  2757. << "\tindex:" << node_with_index.second << '\n';
  2758. }
  2759. }
  2760. ofs << "\t\tactor output data arrow:\n";
  2761. for (const auto &data_arrow : actor->output_data_arrows_) {
  2762. MS_EXCEPTION_IF_NULL(data_arrow);
  2763. ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
  2764. << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
  2765. << "\n";
  2766. }
  2767. ofs << "\t\tactor output result arrow:\n";
  2768. for (const auto &result_arrow : actor->output_result_arrows_) {
  2769. MS_EXCEPTION_IF_NULL(result_arrow);
  2770. ofs << "\t\t\tfrom_output_index:" << result_arrow->from_output_index_
  2771. << "\tto_actor_name:" << result_arrow->to_op_id_.Name() << "\tto_input_index:" << result_arrow->to_input_index_
  2772. << "\n";
  2773. }
  2774. ofs << "\t\tactor output control arrow:\n";
  2775. for (const auto &control_arrow : actor->output_control_arrows_) {
  2776. ofs << "\t\t\tto_actor_name:" << control_arrow;
  2777. }
  2778. ofs << "\n";
  2779. }
  2780. void GraphScheduler::DumpSwitchActor(const SwitchActor *actor, std::ofstream &ofs) const {
  2781. MS_EXCEPTION_IF_NULL(actor);
  2782. ofs << "\tactor_name:" << actor->GetAID().Name() << '\n';
  2783. ofs << "\t\tactor input num:" << actor->input_nodes_.size() << "\n";
  2784. for (const auto &node : actor->input_nodes_) {
  2785. ofs << "\t\t\t" << AnfAlgo::GetNodeDebugString(node.first) << '\t' << node.second << '\n';
  2786. }
  2787. ofs << "\t\tactor input pos:\n";
  2788. for (size_t i = 0; i < actor->branch_inputs_pos_.size(); ++i) {
  2789. ofs << "\t\t\tbranch " << i << " input pos:";
  2790. for (const auto pos : actor->branch_inputs_pos_[i]) {
  2791. ofs << pos << '\t';
  2792. }
  2793. ofs << '\n';
  2794. }
  2795. ofs << "\t\tactor output data arrow:\n";
  2796. for (size_t i = 0; i < actor->output_branch_arrows_.size(); ++i) {
  2797. ofs << "\t\t\tbranch " << i << " output data:\n";
  2798. for (const auto arrow : actor->output_branch_arrows_[i]) {
  2799. MS_EXCEPTION_IF_NULL(arrow);
  2800. ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
  2801. << "\tto_input_index:" << arrow->to_input_index_ << '\n';
  2802. }
  2803. }
  2804. ofs << "\t\tactor output result arrow:\n";
  2805. for (size_t i = 0; i < actor->output_branch_result_arrows_.size(); ++i) {
  2806. ofs << "\t\t\tbranch " << i << " output result:\n";
  2807. for (const auto arrow : actor->output_branch_result_arrows_[i]) {
  2808. MS_EXCEPTION_IF_NULL(arrow);
  2809. ofs << "\t\t\t\t from index:" << arrow->from_output_index_ << "\tto_actor_name:" << arrow->to_op_id_
  2810. << "\tto_input_index:" << arrow->to_input_index_ << '\n';
  2811. }
  2812. }
  2813. ofs << "\t\tactor output control arrow:\n";
  2814. for (size_t i = 0; i < actor->output_branch_control_arrows_.size(); ++i) {
  2815. ofs << "\t\t\tbranch " << i << " output control:\n";
  2816. for (const auto arrow : actor->output_branch_control_arrows_[i]) {
  2817. ofs << "\t\t\t\t from index:" << arrow << '\n';
  2818. }
  2819. }
  2820. ofs << "\n";
  2821. }
  2822. } // namespace runtime
  2823. } // namespace mindspore