You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.cc 46 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_scheduler.h"
  17. #include "runtime/framework/actor/memory_manager_actor.h"
  18. #include "mindrt/src/actor/actormgr.h"
  19. #include "mindrt/include/async/async.h"
  20. #include "backend/session/anf_runtime_algorithm.h"
  21. #include "backend/optimizer/common/helper.h"
  22. #include "utils/config_manager.h"
  23. #include "utils/log_adapter.h"
  24. #include "utils/convert_utils.h"
  25. #include "common/trans.h"
  26. namespace mindspore {
  27. namespace runtime {
  28. namespace {
  29. bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
  30. MS_EXCEPTION_IF_NULL(node);
  31. if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) == kGetNextOpName)) {
  32. return true;
  33. }
  34. return false;
  35. }
  36. bool IsHostQueueDSActor(const AnfNodePtr &node) {
  37. MS_EXCEPTION_IF_NULL(node);
  38. if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
  39. return true;
  40. }
  41. return false;
  42. }
  43. bool IsKernelActor(const AnfNodePtr &node) {
  44. MS_EXCEPTION_IF_NULL(node);
  45. if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) != kGetNextOpName)) {
  46. return true;
  47. }
  48. return false;
  49. }
  50. // Judge whether the device tensor of the node is persistent or not.
  51. bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
  52. MS_EXCEPTION_IF_NULL(node);
  53. if (node->isa<ValueNode>()) {
  54. return true;
  55. }
  56. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  57. return true;
  58. }
  59. return false;
  60. }
  61. KernelActor *FindKernelActor(const KernelMapActor &kernel_actors_map, const std::string &name) {
  62. auto iter = kernel_actors_map.find(name);
  63. if (iter != kernel_actors_map.end()) {
  64. return iter->second.get();
  65. }
  66. return nullptr;
  67. }
  68. DeviceQueueDataSourceActor *FindDeviceQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
  69. for (auto &actor : data_source_actors) {
  70. MS_EXCEPTION_IF_NULL(actor);
  71. if (actor->GetAID().Name().find("_DeviceQueueDataSourceActor") != string::npos) {
  72. auto device_queue_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor.get());
  73. return device_queue_ds_actor;
  74. }
  75. }
  76. return nullptr;
  77. }
  78. HostQueueDataSourceActor *FindHostQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
  79. for (auto &actor : data_source_actors) {
  80. MS_EXCEPTION_IF_NULL(actor);
  81. if (actor->GetAID().Name().find("_HostQueueDataSourceActor") != string::npos) {
  82. auto device_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor.get());
  83. return device_queue_ds_actor;
  84. }
  85. }
  86. return nullptr;
  87. }
// Update the reference count of device tensor by the output index of node.
void UpdateRefCount(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx);
  MS_EXCEPTION_IF_NULL(device_tensor);
  // Bump the persistent (original) count, then reset the working count so the
  // runtime's dynamic ref counting restarts from the new original value.
  device_tensor->IncreaseOriginalRefCount();
  device_tensor->ResetRefCount();
}
  96. // The branch processing of PrepareDataForValueNode that value type is tensor.
  97. void PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
  98. const DeviceContext *device_context) {
  99. MS_EXCEPTION_IF_NULL(node);
  100. MS_EXCEPTION_IF_NULL(node_value);
  101. MS_EXCEPTION_IF_NULL(device_context);
  102. std::vector<TensorPtr> tensors;
  103. TensorValueToTensor(node_value, &tensors);
  104. for (size_t i = 0; i < tensors.size(); i++) {
  105. const auto &tensor = tensors[i];
  106. if (tensor == nullptr) {
  107. MS_LOG(WARNING) << "Tensor is null";
  108. return;
  109. }
  110. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i);
  111. MS_EXCEPTION_IF_NULL(device_tensor);
  112. // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  113. if (device_tensor->GetPtr() != nullptr) {
  114. return;
  115. }
  116. MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope() << ", output index: " << i;
  117. tensor->set_device_address(device_tensor);
  118. // Allocate device memory.
  119. if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
  120. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
  121. << ", alloc size: " << device_tensor->GetSize();
  122. }
  123. // Copy data from host tensor to device.
  124. if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), LongToSize(tensor->data().nbytes()),
  125. tensor->data_type(), tensor->data_c())) {
  126. MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
  127. }
  128. }
  129. }
// Prepare the device data for persistent device tensor of value node.
// Tensor/tuple values are delegated to PrepareDataForValueNodeTensor; string
// values are copied to device as a flat uint8 buffer of shape {1, length}.
void PrepareDataForValueNode(const ValueNodePtr &node, const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(device_context);
  auto &node_value = node->value();
  MS_EXCEPTION_IF_NULL(node_value);
  if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
    // The branch processing that value type is tensor.
    PrepareDataForValueNodeTensor(node, node_value, device_context);
  } else if (node_value->isa<StringImm>()) {
    const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0);
    MS_EXCEPTION_IF_NULL(device_tensor);
    // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
    if (device_tensor->GetPtr() != nullptr) {
      return;
    }
    MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope();
    // Allocate device memory.
    if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
      MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
                        << ", alloc size: " << device_tensor->GetSize();
    }
    // Copy data from value to device.
    auto value = GetValue<std::string>(node_value);
    size_t tensor_size = value.size();
    ShapeVector shape = {1, SizeToLong(tensor_size)};
    if (!device_tensor->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
      MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
    }
  }
  // NOTE(review): other value kinds (e.g. scalars) fall through silently — confirm that is intended.
}
  161. // Prepare the device data for persistent device tensor of weight node from host tensor.
  162. void PrepareDataForWeightNode(const AnfNodePtr &node, const TensorPtr &tensor, const DeviceContext *device_context) {
  163. MS_EXCEPTION_IF_NULL(node);
  164. MS_EXCEPTION_IF_NULL(tensor);
  165. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0);
  166. MS_EXCEPTION_IF_NULL(device_tensor);
  167. const auto &host_tensor_address = std::dynamic_pointer_cast<DeviceTensor>(tensor->device_address());
  168. // If the host tensor has the device address, it indicates that the device address of host tensor is new.
  169. if (host_tensor_address != nullptr) {
  170. if (host_tensor_address != device_tensor) {
  171. AnfAlgo::SetOutputAddr(host_tensor_address, 0, node.get());
  172. DeviceTensorStore::GetInstance().Insert(node.get(), host_tensor_address);
  173. }
  174. return;
  175. }
  176. // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  177. if (device_tensor->GetPtr() != nullptr) {
  178. return;
  179. }
  180. MS_LOG(INFO) << "Prepare device data for weight node: " << node->fullname_with_scope();
  181. tensor->set_device_address(device_tensor);
  182. // Allocate device memory.
  183. if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
  184. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
  185. << ", alloc size: " << device_tensor->GetSize();
  186. }
  187. // Copy data from host tensor to device.
  188. if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), LongToSize(tensor->data().nbytes()),
  189. tensor->data_type(), tensor->data_c())) {
  190. MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
  191. }
  192. }
  193. BaseRef CreateOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
  194. const std::vector<tensor::TensorPtr> &input_tensors) {
  195. auto &node = node_output_pair.first;
  196. auto output_index = node_output_pair.second;
  197. MS_EXCEPTION_IF_NULL(node);
  198. if (node->isa<ValueNode>()) {
  199. // If node is a value node, return the value.
  200. auto value_node = node->cast<ValueNodePtr>();
  201. MS_EXCEPTION_IF_NULL(value_node);
  202. return value_node->value();
  203. } else if (node->isa<Parameter>()) {
  204. // If node is a parameter node, return tensor from input_tensors.
  205. MS_EXCEPTION_IF_NULL(graph);
  206. const auto &input_nodes = graph->inputs();
  207. auto iter = find(input_nodes.begin(), input_nodes.end(), node);
  208. if (iter == input_nodes.end()) {
  209. MS_LOG(EXCEPTION) << "Parameter node: " << node->fullname_with_scope() << " is not exist.";
  210. }
  211. auto position = IntToSize(std::distance(input_nodes.begin(), iter));
  212. return input_tensors[position];
  213. } else {
  214. // Create tensor.
  215. TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  216. if (type_id == kTypeUnknown) {
  217. type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  218. }
  219. std::vector<int64_t> temp_shape;
  220. auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  221. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  222. auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  223. MS_EXCEPTION_IF_NULL(tensor);
  224. tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
  225. // Set device address to tensor.
  226. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index);
  227. MS_EXCEPTION_IF_NULL(device_tensor);
  228. tensor->set_device_address(device_tensor);
  229. device_tensor->set_original_ref_count(SIZE_MAX);
  230. device_tensor->ResetRefCount();
  231. return tensor;
  232. }
  233. }
  234. BaseRef CreateOutputTensors(const AnfNodePtr &output_node, const KernelGraphPtr &graph,
  235. const std::vector<tensor::TensorPtr> &input_tensors) {
  236. MS_EXCEPTION_IF_NULL(output_node);
  237. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(output_node, 0);
  238. MS_EXCEPTION_IF_NULL(item_with_index.first);
  239. // Special handle for make tuple.
  240. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
  241. auto cnode = item_with_index.first->cast<CNodePtr>();
  242. MS_EXCEPTION_IF_NULL(cnode);
  243. VectorRef ret;
  244. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  245. auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors);
  246. ret.push_back(out);
  247. }
  248. return ret;
  249. }
  250. // If the node return nothing, return an empty vectorRef.
  251. if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
  252. return VectorRef();
  253. }
  254. return CreateOutputTensor(item_with_index, graph, input_tensors);
  255. }
  256. void AllocateContinuousMemoryForInput(const AnfNodePtr &kernel, const DeviceContext *device_context,
  257. bool is_all_nop_node) {
  258. MS_EXCEPTION_IF_NULL(kernel);
  259. MS_EXCEPTION_IF_NULL(device_context);
  260. bool is_need_alloc_memory = false;
  261. size_t total_size = 0;
  262. std::vector<size_t> size_list;
  263. std::vector<DeviceTensorPtr> addr_list;
  264. const auto &kernel_mod = AnfAlgo::GetKernelMod(kernel);
  265. MS_EXCEPTION_IF_NULL(kernel_mod);
  266. const auto &intput_sizes = kernel_mod->GetInputSizeList();
  267. for (size_t i = 0; i < intput_sizes.size(); ++i) {
  268. DeviceTensorPtr device_tensor;
  269. if (is_all_nop_node) {
  270. // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
  271. device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
  272. } else {
  273. device_tensor = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, true);
  274. }
  275. MS_EXCEPTION_IF_NULL(device_tensor);
  276. // In the scene of communication op and computing op parallel multi stream, the input address of communication op
  277. // can't be reused, so set the max reference count.
  278. device_tensor->set_original_ref_count(SIZE_MAX);
  279. device_tensor->ResetRefCount();
  280. if (device_tensor->GetPtr() == nullptr) {
  281. is_need_alloc_memory = true;
  282. }
  283. total_size += intput_sizes[i];
  284. size_list.emplace_back(intput_sizes[i]);
  285. addr_list.emplace_back(device_tensor);
  286. }
  287. if (is_need_alloc_memory) {
  288. auto ret = device_context->AllocateContinuousMemory(addr_list, total_size, size_list);
  289. if (!ret) {
  290. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  291. }
  292. }
  293. }
  294. void AllocateContinuousMemoryForOutput(const AnfNodePtr &kernel, const DeviceContext *device_context) {
  295. MS_EXCEPTION_IF_NULL(kernel);
  296. MS_EXCEPTION_IF_NULL(device_context);
  297. bool is_need_alloc_memory = false;
  298. size_t total_size = 0;
  299. std::vector<size_t> size_list;
  300. std::vector<DeviceTensorPtr> addr_list;
  301. const auto &kernel_mod = AnfAlgo::GetKernelMod(kernel);
  302. MS_EXCEPTION_IF_NULL(kernel_mod);
  303. const auto &output_sizes = kernel_mod->GetOutputSizeList();
  304. for (size_t i = 0; i < output_sizes.size(); ++i) {
  305. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
  306. MS_EXCEPTION_IF_NULL(device_tensor);
  307. // One time application for continuous memory, so set the max reference count.
  308. device_tensor->set_original_ref_count(SIZE_MAX);
  309. device_tensor->ResetRefCount();
  310. if (device_tensor->GetPtr() == nullptr) {
  311. is_need_alloc_memory = true;
  312. }
  313. total_size += output_sizes[i];
  314. size_list.emplace_back(output_sizes[i]);
  315. addr_list.emplace_back(device_tensor);
  316. }
  317. if (is_need_alloc_memory) {
  318. auto ret = device_context->AllocateContinuousMemory(addr_list, total_size, size_list);
  319. if (!ret) {
  320. MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  321. }
  322. }
  323. }
  324. } // namespace
// One-shot initialization of the actor runtime: start the thread pool and
// spawn the shared memory manager actor.
// NOTE(review): the init_ guard is not synchronized — assumes single-threaded
// first call; confirm against callers.
void GraphScheduler::Initialize() {
  if (init_) {
    return;
  }
  init_ = true;
  auto actorMgr = ActorMgr::GetActorMgrRef();
  MS_EXCEPTION_IF_NULL(actorMgr);
  // Create the thread pool of actor runtime.
  auto max_thread_num = GetMaxThreadNum();
  MS_LOG(INFO) << "Max available thread number: " << max_thread_num;
  actorMgr->Initialize(max_thread_num);
  // Create memory manager actor.
  auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
  MS_EXCEPTION_IF_NULL(memory_manager_actor);
  memory_manager_aid_ = memory_manager_actor->GetAID();
  // Schedule memory manager actor, bind single thread to response to memory alloc and free quickly.
  auto base_actor = static_cast<ActorReference>(memory_manager_actor);
  (void)actorMgr->Spawn(base_actor, false);
}
  344. ActorSet *GraphScheduler::Transform(const std::vector<KernelGraphPtr> &graphs,
  345. const std::vector<DeviceContext *> &device_contexts,
  346. const std::vector<TensorPtr> *input_tensors,
  347. const std::vector<AnfNodePtr> *control_nodes, GraphExecutionStrategy strategy) {
  348. if (graphs.size() != device_contexts.size()) {
  349. MS_LOG(EXCEPTION) << "The number of graphs is not equal to the number of device_contexts.";
  350. }
  351. Initialize();
  352. std::vector<ActorSetPtr> actor_sets;
  353. for (size_t i = 0; i < graphs.size(); ++i) {
  354. auto graph = graphs[i];
  355. auto device_context = device_contexts[i];
  356. MS_EXCEPTION_IF_NULL(graph);
  357. MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor begin.";
  358. PersistDeviceTensor(graph);
  359. auto actor_set = Build(graph, device_context);
  360. actor_sets.emplace_back(actor_set);
  361. graph_to_actors_.emplace(graph, actor_set);
  362. Link(actor_set.get(), graph, strategy);
  363. if (!CheckActorValid(actor_set.get())) {
  364. MS_LOG(EXCEPTION) << "The actor set of " << graph->ToString() << " is invalid.";
  365. }
  366. MS_LOG(INFO) << "Graph(" << graph->ToString() << ") transforms actor end.";
  367. }
  368. return actor_sets[0].get();
  369. }
// Spawn every actor of the set on the actor manager so they start receiving
// messages.
void GraphScheduler::Schedule(const ActorSet *actor_set) {
  MS_EXCEPTION_IF_NULL(actor_set);
  auto actorMgr = ActorMgr::GetActorMgrRef();
  MS_EXCEPTION_IF_NULL(actorMgr);
  // Schedule data source actors.
  for (auto &data_source_actor : actor_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(data_source_actor);
    auto base_actor = static_cast<ActorReference>(data_source_actor);
    (void)actorMgr->Spawn(base_actor);
  }
  // Schedule kernel actors.
  for (auto &kernel_actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(kernel_actor);
    auto base_actor = static_cast<ActorReference>(kernel_actor);
    (void)actorMgr->Spawn(base_actor);
  }
  // Schedule loop count actor.
  if (actor_set->loop_count_actor_ != nullptr) {
    auto base_actor = static_cast<ActorReference>(actor_set->loop_count_actor_);
    (void)actorMgr->Spawn(base_actor);
  }
}
// Prepare everything a run needs: persistent device data (value nodes and
// weights), the host tensor queue, the output tensors, and continuous memory
// for communication kernels.
// NOTE(review): assumes actor_set->kernel_actors_ is non-empty and that
// input_tensors has one entry per graph input node — confirm with callers.
void GraphScheduler::PrepareRun(const KernelGraphPtr &graph, const std::vector<TensorPtr> *input_tensors,
                                VectorRef *const &outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(outputs);
  // Get the device context for the first kernel actor.
  const auto &actor_set = Fetch(graph);
  MS_EXCEPTION_IF_NULL(actor_set);
  const auto &first_kernel_actor = actor_set->kernel_actors_[0];
  MS_EXCEPTION_IF_NULL(first_kernel_actor);
  const auto &device_context = first_kernel_actor->device_context_;
  // 1.Prepare the data of device tensor store(value nodes of graph).
  for (const auto &value_node : graph->graph_value_nodes()) {
    if (AnfAlgo::OutputAddrExist(value_node, 0)) {
      PrepareDataForValueNode(value_node, device_context);
    }
  }
  // 2.Prepare the data of device tensor store(weights of graph), and fill the host tensors for non weighted parameters.
  std::vector<TensorPtr> host_tensors;
  const auto &input_nodes = graph->input_nodes();
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    const auto &input_node = input_nodes[i];
    const auto &input_tensor = (*input_tensors)[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPersistentDeviceTensor(input_node)) {
      // Prepare the device data for weights.
      PrepareDataForWeightNode(input_node, input_tensor, device_context);
    } else {
      // Fill the host tensors for non weighted parameters.
      host_tensors.emplace_back(input_tensor);
    }
  }
  // 3.Prepare the data of host tensor queue(non weighted parameters of graph).
  const auto &host_tensor_queue = FetchHostQueue(graph);
  if (host_tensor_queue != nullptr) {
    host_tensor_queue->PushData(host_tensors);
  }
  // 4.Prepare the output tensor of graph.
  for (const auto &output_node : graph->outputs()) {
    MS_EXCEPTION_IF_NULL(output_node);
    MS_LOG(INFO) << "Create node output: " << output_node->fullname_with_scope();
    outputs->emplace_back(CreateOutputTensors(output_node, graph, *input_tensors));
  }
  // 5.Prepare the continuous memory for communication kernel.
  for (const auto &kernel : graph->execution_order()) {
    if (AnfAlgo::IsCommunicationOp(kernel)) {
      AllocateContinuousMemoryForInput(kernel, device_context, graph->is_all_nop_node());
      AllocateContinuousMemoryForOutput(kernel, device_context);
    }
  }
}
// Drive one execution of the actor set: trigger the entry actors, wait for the
// run result promise, then sync the device stream. Returns false on failure.
// NOTE(review): assumes actor_set->kernel_actors_ is non-empty when fetching
// the device context below — confirm with callers.
bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(actor_set);
  // Construct OpContext: shared by all actors of this run; result[0] is
  // fulfilled when the run finishes (or fails).
  OpContext<DeviceTensor> op_context;
  uuids::uuid sequential_num;
  std::vector<Promise<int>> result(1);
  op_context.sequential_num_ = &sequential_num;
  op_context.results_ = &result;
  // Trigger no input kernel actor running.
  for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
    MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
    Async(no_input_kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context);
  }
  // Trigger data source actor running.
  for (auto &data_source_actor : actor_set->data_source_actors_) {
    MS_EXCEPTION_IF_NULL(data_source_actor);
    Async(data_source_actor->GetAID(), &DataSourceActor::FetchData, &op_context);
  }
  // Trigger kernel actor running in the step execution strategy.
  for (auto &kernel_actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(kernel_actor);
    Async(kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context);
  }
  // Get the run result: block until the promise set by the actors resolves.
  auto result_future = result[0].GetFuture();
  result_future.Wait();
  if (!result_future.IsOK()) {
    return false;
  }
  // Sync device stream so all async device work of this run is complete.
  const auto &first_kernel_actor = actor_set->kernel_actors_[0];
  MS_EXCEPTION_IF_NULL(first_kernel_actor);
  const auto &device_context = first_kernel_actor->device_context_;
  MS_EXCEPTION_IF_NULL(device_context);
  if (!device_context->SyncStream()) {
    MS_LOG(ERROR) << "Sync stream failed.";
    return false;
  }
  return true;
}
  485. ActorSet *GraphScheduler::Fetch(const KernelGraphPtr &graph) const {
  486. MS_EXCEPTION_IF_NULL(graph);
  487. auto iter = graph_to_actors_.find(graph);
  488. if (iter != graph_to_actors_.end()) {
  489. return iter->second.get();
  490. } else {
  491. MS_LOG(ERROR) << "Can't find the actors map of graph: " << graph->ToString();
  492. return nullptr;
  493. }
  494. }
  495. ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceContext *device_context) {
  496. auto actor_set = std::make_shared<ActorSet>();
  497. MS_EXCEPTION_IF_NULL(actor_set);
  498. auto data_source_actors = BuildDataSourceActor(graph, device_context);
  499. actor_set->data_source_actors_.swap(data_source_actors);
  500. auto kernel_actors = BuildKernelActor(graph, device_context);
  501. actor_set->kernel_actors_.swap(kernel_actors);
  502. auto loop_count_actor = BuildLoopCountActor(graph);
  503. actor_set->loop_count_actor_ = loop_count_actor;
  504. return actor_set;
  505. }
// Wire the actors of the set together: for every kernel in execution order,
// link control arrows (loop count, auto monad) and data arrows from the
// producing actor (device queue / host queue / kernel) to the consuming kernel
// actor, then collect the kernel actors left without inputs.
void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(graph);
  // Index kernel actors by AID name for lookup while linking.
  KernelMapActor kernel_actors_temp_map;
  for (auto &actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    kernel_actors_temp_map.emplace(actor->GetAID().Name(), actor);
  }
  // Foreach the execution order to link the actors.
  auto execution_order = graph->execution_order();
  for (auto &kernel : execution_order) {
    if (!IsKernelActor(kernel)) {
      continue;
    }
    // NOTE(review): FindKernelActor may return nullptr; assumes the Link*
    // helpers tolerate that — confirm.
    auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope());
    // Link the control arrows of kernel actor.
    LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy);
    for (size_t i = 0; i < AnfAlgo::GetInputNum(kernel); ++i) {
      auto input_node = AnfAlgo::GetInputNode(kernel, i);
      // Link the control arrows of kernel actor by the auto monad, the inputs include monad node.
      LinkControlArrowByAutoMonad(kernel_actor, input_node, kernel_actors_temp_map);
      if (HasAbstractMonad(input_node)) {
        continue;  // No data arrow for monad input.
      }
      KernelWithIndex from_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
      KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
      auto from_kernel = from_kernel_with_output_idx.first;
      if (IsDeviceQueueDSActor(from_kernel)) {
        // Link the data arrows of device queue data source actor.
        auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
        LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
      } else if (IsHostQueueDSActor(from_kernel)) {
        // Link the data arrows of host queue data source actor.
        auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
        LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
      } else {
        // Link the data arrows of kernel actor.
        auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope());
        LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
      }
    }
  }
  // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
  auto no_input_kernel_actors = BuildNoInputKernelActor(graph);
  actor_set->no_input_kernel_actors_.swap(no_input_kernel_actors);
  // Link the control arrows of loop count actor, which depends on the no input kernel actors.
  LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), graph);
}
// Build the data source actors of a graph: at most one host queue actor (for
// all non-weight parameters) and at most one device queue actor (for the first
// GetNext kernel found in the execution order).
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const KernelGraphPtr &graph,
                                                                     const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<DataSourceActorPtr> data_source_actors;
  // Build host queue data source actor.
  HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  for (auto &input_node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsHostQueueDSActor(input_node)) {
      // Lazily create the single host queue actor on first matching parameter.
      if (host_queue_ds_actor == nullptr) {
        auto actor_name = graph->ToString() + "_" + "HostQueueDataSourceActor";
        MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
        auto host_queue = std::make_shared<HostTensorQueue>();
        graph_to_host_queue_.emplace(graph, host_queue);
        host_queue_ds_actor =
          std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
        data_source_actors.emplace_back(host_queue_ds_actor);
      }
      host_queue_ds_actor->data_nodes_.emplace_back(input_node);
    }
  }
  // Build device queue data source actor.
  auto execution_order = graph->execution_order();
  auto iter = std::find_if(execution_order.begin(), execution_order.end(),
                           [](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
  if (iter != execution_order.end()) {
    auto actor_name = graph->ToString() + "_" + "DeviceQueueDataSourceActor";
    MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
    auto device_queue_ds_actor =
      std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
    MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
    data_source_actors.emplace_back(device_queue_ds_actor);
    device_queue_ds_actor->data_kernel_ = *iter;
  }
  return data_source_actors;
}
  590. std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const KernelGraphPtr &graph,
  591. const DeviceContext *device_context) {
  592. MS_EXCEPTION_IF_NULL(graph);
  593. std::vector<KernelActorPtr> kernel_actors;
  594. auto execution_order = graph->execution_order();
  595. for (auto &kernel : execution_order) {
  596. if (IsKernelActor(kernel)) {
  597. auto kernel_actor =
  598. std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
  599. MS_EXCEPTION_IF_NULL(kernel_actor);
  600. kernel_actors.emplace_back(kernel_actor);
  601. }
  602. }
  603. return kernel_actors;
  604. }
  605. std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const KernelGraphPtr &graph) {
  606. MS_EXCEPTION_IF_NULL(graph);
  607. std::vector<KernelActorPtr> no_input_kernel_actors;
  608. auto actor_set = Fetch(graph);
  609. MS_EXCEPTION_IF_NULL(actor_set);
  610. for (auto &kernel_actor : actor_set->kernel_actors_) {
  611. MS_EXCEPTION_IF_NULL(kernel_actor);
  612. if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
  613. no_input_kernel_actors.emplace_back(kernel_actor);
  614. // The no input kernel actor will be triggered by loop count actor, so need set the input_controls_num_.
  615. kernel_actor->input_controls_num_ = 1;
  616. }
  617. }
  618. return no_input_kernel_actors;
  619. }
  620. LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const KernelGraphPtr &graph) {
  621. MS_EXCEPTION_IF_NULL(graph);
  622. auto loop_count = ConfigManager::GetInstance().iter_num();
  623. auto actor_name = graph->ToString() + "_" + "LoopCountActor";
  624. auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
  625. MS_LOG(INFO) << "Create loop count actor: " << actor_name;
  626. MS_EXCEPTION_IF_NULL(loop_count_actor);
  627. return loop_count_actor;
  628. }
  629. void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
  630. KernelWithIndex from_kernel_with_output_idx,
  631. KernelWithIndex to_kernel_with_input_idx) {
  632. MS_EXCEPTION_IF_NULL(from_actor);
  633. MS_EXCEPTION_IF_NULL(to_actor);
  634. auto from_kernel = from_kernel_with_output_idx.first;
  635. MS_EXCEPTION_IF_NULL(from_kernel);
  636. auto from_output_index = from_kernel_with_output_idx.second;
  637. auto to_input_index = to_kernel_with_input_idx.second;
  638. auto to_aid = to_actor->GetAID();
  639. auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
  640. from_actor->output_op_arrows_.emplace_back(op_arrow);
  641. to_actor->input_datas_num_++;
  642. // Update the reference count of device tensor.
  643. UpdateRefCount(from_kernel, from_output_index);
  644. }
  645. void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
  646. KernelWithIndex from_kernel_with_output_idx,
  647. KernelWithIndex to_kernel_with_input_idx) {
  648. MS_EXCEPTION_IF_NULL(from_actor);
  649. MS_EXCEPTION_IF_NULL(to_actor);
  650. auto from_kernel = from_kernel_with_output_idx.first;
  651. MS_EXCEPTION_IF_NULL(from_kernel);
  652. auto from_output_index = from_kernel_with_output_idx.second;
  653. auto to_input_index = to_kernel_with_input_idx.second;
  654. auto data_nodes = from_actor->data_nodes_;
  655. auto iter = find(data_nodes.begin(), data_nodes.end(), from_kernel);
  656. if (iter == data_nodes.end()) {
  657. MS_LOG(EXCEPTION) << "Parameter node: " << from_kernel->fullname_with_scope() << " is not exist.";
  658. }
  659. auto position = IntToSize(std::distance(data_nodes.begin(), iter));
  660. auto to_aid = to_actor->GetAID();
  661. auto op_arrow = std::make_shared<OpArrow>(position, to_aid, to_input_index);
  662. from_actor->output_op_arrows_.emplace_back(op_arrow);
  663. to_actor->input_datas_num_++;
  664. // Update the reference count of device tensor.
  665. UpdateRefCount(from_kernel, from_output_index);
  666. }
  667. void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
  668. KernelWithIndex from_kernel_with_output_idx,
  669. KernelWithIndex to_kernel_with_input_idx) {
  670. MS_EXCEPTION_IF_NULL(to_actor);
  671. auto from_kernel = from_kernel_with_output_idx.first;
  672. MS_EXCEPTION_IF_NULL(from_kernel);
  673. auto from_output_index = from_kernel_with_output_idx.second;
  674. auto to_input_index = to_kernel_with_input_idx.second;
  675. if (IsPersistentDeviceTensor(from_kernel)) {
  676. to_actor->device_tensor_store_keys_.emplace_back(to_input_index, static_cast<void *>(from_kernel.get()));
  677. } else if (IsKernelActor(from_kernel)) {
  678. MS_EXCEPTION_IF_NULL(from_actor);
  679. auto to_aid = to_actor->GetAID();
  680. auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
  681. from_actor->output_op_arrows_.emplace_back(op_arrow);
  682. to_actor->input_datas_num_++;
  683. // Update the reference count of device tensor.
  684. UpdateRefCount(from_kernel, from_output_index);
  685. }
  686. }
// Links a control arrow from a kernel actor to the loop count actor when the kernel's
// output is not really consumed by any other node, so the loop count actor can observe
// that the iteration's "leaf" kernels have finished.
void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor,
                                                    const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(from_actor);
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(graph);
  // In step execution mode every kernel actor additionally waits for one control input.
  if (strategy == GraphExecutionStrategy::kStep) {
    from_actor->input_controls_num_++;
  }
  // The manager of graph member is weak ptr, so need created and used in the function IsNotRealUsedByOthers.
  // NOTE(review): the returned manager must stay alive across the IsNotRealUsedByOthers call below.
  const auto &manager = Manage(graph, true);
  MS_EXCEPTION_IF_NULL(manager);
  if (opt::IsNotRealUsedByOthers(graph, from_actor->kernel_)) {
    MS_EXCEPTION_IF_NULL(from_actor->kernel_);
    MS_LOG(INFO) << from_actor->kernel_->fullname_with_scope() << " is not real used by other nodes.";
    // Leaf kernel: notify the loop count actor when it finishes.
    auto to_aid = to_actor->GetAID();
    from_actor->output_op_controls_.emplace_back(to_aid);
    to_actor->input_controls_num_++;
  }
}
// Links control arrows that encode execution-order dependencies expressed through
// auto monad nodes (UpdateState/Load) and MakeTuple wrappers: recursively peels the
// wrapper nodes off from_node until a real kernel is found, then links a control
// arrow from that kernel's actor to to_actor.
void GraphScheduler::LinkControlArrowByAutoMonad(KernelActor *to_actor, const AnfNodePtr &from_node,
                                                 const KernelMapActor &kernel_actors_map) {
  MS_EXCEPTION_IF_NULL(to_actor);
  MS_EXCEPTION_IF_NULL(from_node);
  // Only CNodes can carry monad dependencies.
  if (!from_node->isa<CNode>()) {
    return;
  }
  // Find the real input node, include the monad node and make tuple node.
  const std::vector<PrimitivePtr> &return_types = {prim::kPrimUpdateState, prim::kPrimLoad, prim::kPrimMakeTuple};
  const auto &input_kernel_with_output_idx = AnfAlgo::VisitKernelWithReturnType(from_node, 0, true, return_types);
  MS_EXCEPTION_IF_NULL(input_kernel_with_output_idx.first);
  if (!input_kernel_with_output_idx.first->isa<CNode>()) {
    return;
  }
  const auto &input_cnode = input_kernel_with_output_idx.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(input_cnode);
  // Get the real depend input by monad node which needs to link the control arrow.
  AnfNodePtr real_depend_input = nullptr;
  if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimUpdateState)) {
    real_depend_input = input_cnode->input(kUpdateStateRealInput);
  } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimLoad)) {
    real_depend_input = input_cnode->input(kLoadStateInput);
  } else if (AnfAlgo::CheckPrimitiveType(input_cnode, prim::kPrimMakeTuple)) {
    // Make tuple node needs to be expanded: link each element separately
    // (input 0 is the MakeTuple primitive itself, so start at 1).
    for (size_t i = 1; i < input_cnode->inputs().size(); ++i) {
      LinkControlArrowByAutoMonad(to_actor, input_cnode->input(i), kernel_actors_map);
    }
    return;
  } else {
    // Not a monad/tuple wrapper: nothing to link here.
    return;
  }
  MS_EXCEPTION_IF_NULL(real_depend_input);
  if (!real_depend_input->isa<CNode>()) {
    return;
  }
  // The monad node and make tuple node need recursion.
  if (AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimUpdateState) ||
      AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimLoad) ||
      AnfAlgo::CheckPrimitiveType(real_depend_input, prim::kPrimMakeTuple)) {
    LinkControlArrowByAutoMonad(to_actor, real_depend_input, kernel_actors_map);
    return;
  }
  // Link the control arrow between the kernel actors.
  auto from_actor = FindKernelActor(kernel_actors_map, real_depend_input->fullname_with_scope());
  MS_EXCEPTION_IF_NULL(from_actor);
  from_actor->output_op_controls_.emplace_back(to_actor->GetAID());
  to_actor->input_controls_num_++;
}
  754. void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph) {
  755. MS_EXCEPTION_IF_NULL(graph);
  756. MS_EXCEPTION_IF_NULL(loop_count_actor);
  757. auto actor_set = Fetch(graph);
  758. MS_EXCEPTION_IF_NULL(actor_set);
  759. // Set the source data actor.
  760. for (auto &data_source_actor : actor_set->data_source_actors_) {
  761. MS_EXCEPTION_IF_NULL(data_source_actor);
  762. loop_count_actor->data_source_aids_.emplace_back(data_source_actor->GetAID());
  763. }
  764. // Set the no input kernel actor.
  765. for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  766. MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
  767. loop_count_actor->no_input_kernel_aids_.emplace_back(no_input_kernel_actor->GetAID());
  768. }
  769. }
  770. bool GraphScheduler::CheckActorValid(const ActorSet *actor_set) const {
  771. MS_EXCEPTION_IF_NULL(actor_set);
  772. // Check the data source actors.
  773. for (const auto &data_source_actor : actor_set->data_source_actors_) {
  774. MS_EXCEPTION_IF_NULL(data_source_actor);
  775. if (data_source_actor->output_op_arrows_.size() == 0) {
  776. MS_LOG(ERROR) << data_source_actor->GetAID().Name() << " has no user.";
  777. return false;
  778. }
  779. }
  780. // Check the kernel actors.
  781. for (const auto &kernel_actor : actor_set->kernel_actors_) {
  782. MS_EXCEPTION_IF_NULL(kernel_actor);
  783. if (kernel_actor->output_op_arrows_.size() + kernel_actor->output_op_controls_.size() == 0) {
  784. MS_LOG(ERROR) << kernel_actor->GetAID().Name() << " has no user.";
  785. return false;
  786. }
  787. auto input_num = AnfAlgo::GetInputTensorNum(kernel_actor->kernel_);
  788. auto input_data_num = kernel_actor->input_datas_num_;
  789. auto device_tensor_store_num = kernel_actor->device_tensor_store_keys_.size();
  790. if (input_data_num + device_tensor_store_num != input_num) {
  791. MS_LOG(ERROR) << "The input building of " << kernel_actor->GetAID().Name()
  792. << " is wrong, input data num: " << input_data_num
  793. << ", device tensor store num: " << device_tensor_store_num << ", total input num: " << input_num;
  794. return false;
  795. }
  796. }
  797. // Check the loop count actor.
  798. const auto &loop_count_actor = actor_set->loop_count_actor_;
  799. if (loop_count_actor != nullptr) {
  800. if (loop_count_actor->input_controls_num_ == 0) {
  801. MS_LOG(ERROR) << loop_count_actor->GetAID().Name() << " has no source.";
  802. return false;
  803. }
  804. }
  805. return true;
  806. }
  807. void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) {
  808. MS_EXCEPTION_IF_NULL(graph);
  809. for (auto &value_node : graph->graph_value_nodes()) {
  810. MS_EXCEPTION_IF_NULL(value_node);
  811. if (!AnfAlgo::OutputAddrExist(value_node, 0)) {
  812. MS_LOG(INFO) << "The device address is not exist: " << value_node->ToString();
  813. continue;
  814. }
  815. auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
  816. DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
  817. device_tensor->set_original_ref_count(SIZE_MAX);
  818. device_tensor->ResetRefCount();
  819. }
  820. for (auto &input_node : graph->input_nodes()) {
  821. MS_EXCEPTION_IF_NULL(input_node);
  822. if (IsPersistentDeviceTensor(input_node)) {
  823. auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
  824. MS_EXCEPTION_IF_NULL(device_tensor);
  825. DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
  826. device_tensor->set_original_ref_count(SIZE_MAX);
  827. device_tensor->ResetRefCount();
  828. }
  829. }
  830. }
  831. HostTensorQueue *GraphScheduler::FetchHostQueue(const KernelGraphPtr &graph) const {
  832. MS_EXCEPTION_IF_NULL(graph);
  833. const auto &iter = graph_to_host_queue_.find(graph);
  834. if (iter != graph_to_host_queue_.end()) {
  835. return iter->second.get();
  836. } else {
  837. return nullptr;
  838. }
  839. }
  840. void GraphScheduler::DumpActor(const KernelGraphPtr &graph) const {
  841. MS_EXCEPTION_IF_NULL(graph);
  842. const auto &actor_set = Fetch(graph);
  843. MS_EXCEPTION_IF_NULL(actor_set);
  844. std::string filename = "./actor_set_" + graph->ToString() + ".ir";
  845. std::ofstream ofs(filename);
  846. if (!ofs.is_open()) {
  847. MS_LOG(ERROR) << "Open file [" << filename << "] failed!";
  848. return;
  849. }
  850. ofs << "[Data source actors]\n";
  851. for (const auto &data_source_actor : actor_set->data_source_actors_) {
  852. DumpDSActor(data_source_actor.get(), ofs);
  853. ofs << "\n";
  854. }
  855. ofs << "\n[Kernel actors]\n";
  856. for (const auto &kernel_actor : actor_set->kernel_actors_) {
  857. DumpKernelActor(kernel_actor.get(), ofs);
  858. ofs << "\n";
  859. }
  860. ofs << "\n[No input kernel actors]\n";
  861. for (const auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  862. DumpKernelActor(no_input_kernel_actor.get(), ofs);
  863. ofs << "\n";
  864. }
  865. ofs << "\n[Loop count actor]\n";
  866. const auto &loop_count_actor = actor_set->loop_count_actor_;
  867. if (loop_count_actor != nullptr) {
  868. DumpLoopCountActor(loop_count_actor.get(), ofs);
  869. ofs << "\n";
  870. }
  871. }
  872. void GraphScheduler::DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) const {
  873. MS_EXCEPTION_IF_NULL(actor);
  874. const auto &actor_name = actor->GetAID().Name();
  875. MS_EXCEPTION_IF_NULL(actor->device_context_);
  876. ofs << "\tactor_name:" << actor_name << "\tdevice_context:" << actor->device_context_->device_context_key().ToString()
  877. << "\n";
  878. if (actor_name.find("_DeviceQueueDataSourceActor") != string::npos) {
  879. // Dump the member info of device queue data source actor.
  880. const auto &device_queue_ds_actor = dynamic_cast<const DeviceQueueDataSourceActor *>(actor);
  881. const auto &data_kernel = device_queue_ds_actor->data_kernel_;
  882. MS_EXCEPTION_IF_NULL(data_kernel);
  883. ofs << "\t\tdata_kernel_name:" << data_kernel->fullname_with_scope()
  884. << "\tinput_number:" << AnfAlgo::GetInputTensorNum(data_kernel)
  885. << "\toutput_number:" << AnfAlgo::GetOutputTensorNum(data_kernel) << "\n";
  886. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(data_kernel); ++i) {
  887. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_kernel, i, false);
  888. MS_EXCEPTION_IF_NULL(device_tensor);
  889. ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
  890. << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
  891. }
  892. } else if (actor_name.find("_HostQueueDataSourceActor") != string::npos) {
  893. // Dump the member info of host queue data source actor.
  894. const auto &host_queue_ds_actor = dynamic_cast<const HostQueueDataSourceActor *>(actor);
  895. ofs << "\t\tdata_nodes:" << host_queue_ds_actor->data_nodes_.size() << "\n";
  896. for (size_t i = 0; i < host_queue_ds_actor->data_nodes_.size(); ++i) {
  897. const auto &data_node = host_queue_ds_actor->data_nodes_[i];
  898. MS_EXCEPTION_IF_NULL(data_node);
  899. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(data_node, 0, false);
  900. MS_EXCEPTION_IF_NULL(device_tensor);
  901. ofs << "\t\t\tnode_order_number:" << i << "\tnode_name:" << data_node->fullname_with_scope()
  902. << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
  903. << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
  904. }
  905. }
  906. ofs << "\t\toutput_data_arrows:" << actor->output_op_arrows_.size() << "\n ";
  907. for (const auto &data_arrow : actor->output_op_arrows_) {
  908. MS_EXCEPTION_IF_NULL(data_arrow);
  909. ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
  910. << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
  911. << "\n";
  912. }
  913. }
  914. void GraphScheduler::DumpLoopCountActor(const LoopCountActor *actor, std::ofstream &ofs) const {
  915. MS_EXCEPTION_IF_NULL(actor);
  916. ofs << "\tactor_name:" << actor->GetAID().Name() << "\tloop_count:" << actor->loop_count_
  917. << "\tinput_controls_num:" << actor->input_controls_num_ << "\n";
  918. ofs << "\t\toutput_control_arrows:" << (actor->data_source_aids_.size() + actor->no_input_kernel_aids_.size())
  919. << "\n ";
  920. for (const auto &aid : actor->data_source_aids_) {
  921. ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
  922. }
  923. for (const auto &aid : actor->no_input_kernel_aids_) {
  924. ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
  925. }
  926. }
  927. void GraphScheduler::DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) const {
  928. MS_EXCEPTION_IF_NULL(actor);
  929. MS_EXCEPTION_IF_NULL(actor->device_context_);
  930. ofs << "\tactor_name:" << actor->GetAID().Name()
  931. << "\tdevice_context:" << actor->device_context_->device_context_key().ToString()
  932. << "\tinput_data_num:" << actor->input_datas_num_ << "\tinput_controls_num:" << actor->input_controls_num_
  933. << "\n";
  934. const auto &kernel = actor->kernel_;
  935. MS_EXCEPTION_IF_NULL(kernel);
  936. ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinput_number:" << AnfAlgo::GetInputTensorNum(kernel)
  937. << "\toutput_number:" << AnfAlgo::GetOutputTensorNum(kernel) << "\n";
  938. for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
  939. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
  940. MS_EXCEPTION_IF_NULL(device_tensor);
  941. ofs << "\t\t\toutput_index:" << i << "\tptr:" << device_tensor->GetPtr() << "\tsize:" << device_tensor->GetSize()
  942. << "\toriginal_ref_count:" << device_tensor->original_ref_count() << "\n ";
  943. }
  944. ofs << "\t\tdevice_tensor_stores:" << actor->device_tensor_store_keys_.size() << "\n ";
  945. for (const auto &device_tensor_store_key : actor->device_tensor_store_keys_) {
  946. const auto &node = reinterpret_cast<AnfNode *>(device_tensor_store_key.second);
  947. MS_EXCEPTION_IF_NULL(node);
  948. ofs << "\t\t\tto_input_index:" << device_tensor_store_key.first
  949. << "\tfrom_node_name:" << node->fullname_with_scope() << "\n";
  950. }
  951. ofs << "\t\toutput_data_arrows:" << actor->output_op_arrows_.size() << "\n ";
  952. for (const auto &data_arrow : actor->output_op_arrows_) {
  953. MS_EXCEPTION_IF_NULL(data_arrow);
  954. ofs << "\t\t\tfrom_output_index:" << data_arrow->from_output_index_
  955. << "\tto_actor_name:" << data_arrow->to_op_id_.Name() << "\tto_input_index:" << data_arrow->to_input_index_
  956. << "\n";
  957. }
  958. ofs << "\t\toutput_control_arrows:" << actor->output_op_controls_.size() << "\n ";
  959. for (const auto &aid : actor->output_op_controls_) {
  960. ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
  961. }
  962. }
  963. } // namespace runtime
  964. } // namespace mindspore