You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

graph_scheduler.cc 29 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/framework/graph_scheduler.h"
  17. #include "runtime/framework/actor/memory_manager_actor.h"
  18. #include "mindrt/src/actor/actormgr.h"
  19. #include "mindrt/include/async/async.h"
  20. #include "backend/session/anf_runtime_algorithm.h"
  21. #include "backend/optimizer/common/helper.h"
  22. #include "utils/config_manager.h"
  23. #include "utils/log_adapter.h"
  24. #include "utils/convert_utils.h"
  25. #include "common/trans.h"
  26. namespace mindspore {
  27. namespace runtime {
  28. namespace {
  29. bool IsDeviceQueueDSActor(const AnfNodePtr &node) {
  30. MS_EXCEPTION_IF_NULL(node);
  31. if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) == kGetNextOpName)) {
  32. return true;
  33. }
  34. return false;
  35. }
  36. bool IsHostQueueDSActor(const AnfNodePtr &node) {
  37. MS_EXCEPTION_IF_NULL(node);
  38. if (node->isa<Parameter>() && (!AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>()))) {
  39. return true;
  40. }
  41. return false;
  42. }
  43. bool IsKernelActor(const AnfNodePtr &node) {
  44. MS_EXCEPTION_IF_NULL(node);
  45. if (node->isa<CNode>() && (AnfAlgo::GetCNodeName(node) != kGetNextOpName)) {
  46. return true;
  47. }
  48. return false;
  49. }
  50. // Judge whether the device tensor of the node is persistent or not.
  51. bool IsPersistentDeviceTensor(const AnfNodePtr &node) {
  52. MS_EXCEPTION_IF_NULL(node);
  53. if (node->isa<ValueNode>()) {
  54. return true;
  55. }
  56. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  57. return true;
  58. }
  59. return false;
  60. }
  61. KernelActor *FindKernelActor(const std::unordered_map<std::string, KernelActorPtr> &kernel_actors_map,
  62. const std::string &name) {
  63. auto iter = kernel_actors_map.find(name);
  64. if (iter != kernel_actors_map.end()) {
  65. return iter->second.get();
  66. }
  67. return nullptr;
  68. }
  69. DeviceQueueDataSourceActor *FindDeviceQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
  70. for (auto &actor : data_source_actors) {
  71. MS_EXCEPTION_IF_NULL(actor);
  72. if (actor->GetAID().Name().find("_DeviceQueueDataSourceActor") != string::npos) {
  73. auto device_queue_ds_actor = dynamic_cast<DeviceQueueDataSourceActor *>(actor.get());
  74. return device_queue_ds_actor;
  75. }
  76. }
  77. return nullptr;
  78. }
  79. HostQueueDataSourceActor *FindHostQueueDSActor(const std::vector<DataSourceActorPtr> &data_source_actors) {
  80. for (auto &actor : data_source_actors) {
  81. MS_EXCEPTION_IF_NULL(actor);
  82. if (actor->GetAID().Name().find("_HostQueueDataSourceActor") != string::npos) {
  83. auto device_queue_ds_actor = dynamic_cast<HostQueueDataSourceActor *>(actor.get());
  84. return device_queue_ds_actor;
  85. }
  86. }
  87. return nullptr;
  88. }
// Update the reference count of device tensor by the output index of node.
// Each call registers one more consumer of that output's device memory, then resets the
// "used" counter so the new count takes effect for the next run.
void UpdateRefCount(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_idx);
  MS_EXCEPTION_IF_NULL(device_tensor);
  device_tensor->IncreaseRefCount();
  device_tensor->ResetRefCountUsed();
}
  97. // The branch processing of PrepareDataForValueNode that value type is tensor.
  98. void PrepareDataForValueNodeTensor(const ValueNodePtr &node, const ValuePtr &node_value,
  99. const DeviceContext *device_context) {
  100. MS_EXCEPTION_IF_NULL(node);
  101. MS_EXCEPTION_IF_NULL(node_value);
  102. MS_EXCEPTION_IF_NULL(device_context);
  103. std::vector<TensorPtr> tensors;
  104. TensorValueToTensor(node_value, &tensors);
  105. for (size_t i = 0; i < tensors.size(); i++) {
  106. const auto &tensor = tensors[i];
  107. if (tensor == nullptr) {
  108. MS_LOG(WARNING) << "Tensor is null";
  109. return;
  110. }
  111. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, i);
  112. MS_EXCEPTION_IF_NULL(device_tensor);
  113. // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  114. if (device_tensor->GetPtr() != nullptr) {
  115. return;
  116. }
  117. MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope() << ", output index: " << i;
  118. tensor->set_device_address(device_tensor);
  119. // Allocate device memory.
  120. if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
  121. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
  122. << ", alloc size: " << device_tensor->GetSize();
  123. }
  124. // Copy data from host tensor to device.
  125. if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), LongToSize(tensor->data().nbytes()),
  126. tensor->data_type(), tensor->data_c())) {
  127. MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
  128. }
  129. }
  130. }
  131. // Prepare the device data for persistent device tensor of value node.
  132. void PrepareDataForValueNode(const ValueNodePtr &node, const DeviceContext *device_context) {
  133. MS_EXCEPTION_IF_NULL(node);
  134. MS_EXCEPTION_IF_NULL(device_context);
  135. auto &node_value = node->value();
  136. MS_EXCEPTION_IF_NULL(node_value);
  137. if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
  138. // The branch processing that value type is tensor.
  139. PrepareDataForValueNodeTensor(node, node_value, device_context);
  140. } else if (node_value->isa<StringImm>()) {
  141. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0);
  142. MS_EXCEPTION_IF_NULL(device_tensor);
  143. // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  144. if (device_tensor->GetPtr() != nullptr) {
  145. return;
  146. }
  147. MS_LOG(INFO) << "Prepare device data for value node: " << node->fullname_with_scope();
  148. // Allocate device memory.
  149. if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
  150. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
  151. << ", alloc size: " << device_tensor->GetSize();
  152. }
  153. // Copy data from value to device.
  154. auto value = GetValue<std::string>(node_value);
  155. size_t tensor_size = value.size();
  156. ShapeVector shape = {1, SizeToLong(tensor_size)};
  157. if (!device_tensor->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
  158. MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
  159. }
  160. }
  161. }
  162. // Prepare the device data for persistent device tensor of weight node from host tensor.
  163. void PrepareDataForWeightNode(const AnfNodePtr &node, const TensorPtr &tensor, const DeviceContext *device_context) {
  164. MS_EXCEPTION_IF_NULL(node);
  165. MS_EXCEPTION_IF_NULL(tensor);
  166. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, 0);
  167. MS_EXCEPTION_IF_NULL(device_tensor);
  168. // If the ptr of device tensor is not nullptr, it indicates that the device data has been prepared.
  169. if (device_tensor->GetPtr() != nullptr) {
  170. return;
  171. }
  172. MS_LOG(INFO) << "Prepare device data for weight node: " << node->fullname_with_scope();
  173. tensor->set_device_address(device_tensor);
  174. // Allocate device memory.
  175. if (!device_context->AllocateMemory(device_tensor.get(), device_tensor->GetSize())) {
  176. MS_LOG(EXCEPTION) << "Device memory isn't enough and alloc failed, node name: " << node->fullname_with_scope()
  177. << ", alloc size: " << device_tensor->GetSize();
  178. }
  179. // Copy data from host tensor to device.
  180. if (!device_tensor->SyncHostToDevice(trans::GetRuntimePaddingShape(node, 0), LongToSize(tensor->data().nbytes()),
  181. tensor->data_type(), tensor->data_c())) {
  182. MS_LOG(EXCEPTION) << "SyncHostToDevice failed, node name: " << node->fullname_with_scope();
  183. }
  184. }
  185. BaseRef CreateOutputTensor(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
  186. const std::vector<tensor::TensorPtr> &input_tensors) {
  187. auto &node = node_output_pair.first;
  188. auto output_index = node_output_pair.second;
  189. MS_EXCEPTION_IF_NULL(node);
  190. if (node->isa<ValueNode>()) {
  191. // If node is a value node, return the value.
  192. auto value_node = node->cast<ValueNodePtr>();
  193. MS_EXCEPTION_IF_NULL(value_node);
  194. return value_node->value();
  195. } else if (node->isa<Parameter>()) {
  196. // If node is a parameter node, return tensor from input_tensors.
  197. MS_EXCEPTION_IF_NULL(graph);
  198. const auto &input_nodes = graph->inputs();
  199. auto iter = find(input_nodes.begin(), input_nodes.end(), node);
  200. if (iter == input_nodes.end()) {
  201. MS_LOG(EXCEPTION) << "Parameter node: " << node->fullname_with_scope() << " is not exist.";
  202. }
  203. auto position = IntToSize(std::distance(input_nodes.begin(), iter));
  204. return input_tensors[position];
  205. } else {
  206. // Create tensor.
  207. TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
  208. if (type_id == kTypeUnknown) {
  209. type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
  210. }
  211. std::vector<int64_t> temp_shape;
  212. auto shape = AnfAlgo::GetOutputInferShape(node, output_index);
  213. (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape));
  214. auto tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
  215. MS_EXCEPTION_IF_NULL(tensor);
  216. tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index));
  217. // Set device address to tensor.
  218. const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(node, output_index);
  219. MS_EXCEPTION_IF_NULL(device_tensor);
  220. tensor->set_device_address(device_tensor);
  221. return tensor;
  222. }
  223. }
  224. BaseRef CreateOutputTensors(const AnfNodePtr &output_node, const KernelGraphPtr &graph,
  225. const std::vector<tensor::TensorPtr> &input_tensors) {
  226. MS_EXCEPTION_IF_NULL(output_node);
  227. auto item_with_index = AnfAlgo::VisitKernelWithReturnType(output_node, 0);
  228. MS_EXCEPTION_IF_NULL(item_with_index.first);
  229. // Special handle for make tuple.
  230. if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
  231. auto cnode = item_with_index.first->cast<CNodePtr>();
  232. MS_EXCEPTION_IF_NULL(cnode);
  233. VectorRef ret;
  234. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  235. auto out = CreateOutputTensors(cnode->input(i), graph, input_tensors);
  236. ret.push_back(out);
  237. }
  238. return ret;
  239. }
  240. // If the node return nothing, return an empty vectorRef.
  241. if (AnfAlgo::GetOutputTensorNum(item_with_index.first) == 0) {
  242. return VectorRef();
  243. }
  244. return CreateOutputTensor(item_with_index, graph, input_tensors);
  245. }
  246. } // namespace
  247. void GraphScheduler::Initialize() {
  248. if (init_) {
  249. return;
  250. }
  251. init_ = true;
  252. // Create memory manager actor.
  253. auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
  254. MS_EXCEPTION_IF_NULL(memory_manager_actor);
  255. memory_manager_aid_ = memory_manager_actor->GetAID();
  256. // Schedule memory manager actor, bind single thread to response to memory alloc and free quickly.
  257. auto base_actor = static_cast<ActorReference>(memory_manager_actor);
  258. auto actorMgr = ActorMgr::GetActorMgrRef();
  259. MS_EXCEPTION_IF_NULL(actorMgr);
  260. (void)actorMgr->Spawn(base_actor, false);
  261. }
  262. ActorSet *GraphScheduler::Transform(const KernelGraphPtr &graph, const DeviceContext *device_context,
  263. const std::vector<tensor::TensorPtr> *input_tensors,
  264. GraphExecutionStrategy strategy) {
  265. PersistDeviceTensor(graph);
  266. auto actor_set = Build(graph, device_context);
  267. graph_to_actors_.emplace(graph, actor_set);
  268. Link(actor_set.get(), graph, strategy);
  269. return actor_set.get();
  270. }
  271. void GraphScheduler::Schedule(const ActorSet *actor_set) {
  272. MS_EXCEPTION_IF_NULL(actor_set);
  273. auto actorMgr = ActorMgr::GetActorMgrRef();
  274. MS_EXCEPTION_IF_NULL(actorMgr);
  275. // Schedule dats source actors.
  276. for (auto &data_source_actor : actor_set->data_source_actors_) {
  277. MS_EXCEPTION_IF_NULL(data_source_actor);
  278. auto base_actor = static_cast<ActorReference>(data_source_actor);
  279. (void)actorMgr->Spawn(base_actor);
  280. }
  281. // Schedule kernel actors.
  282. for (auto &kernel_actor : actor_set->kernel_actors_) {
  283. MS_EXCEPTION_IF_NULL(kernel_actor);
  284. auto base_actor = static_cast<ActorReference>(kernel_actor);
  285. (void)actorMgr->Spawn(base_actor);
  286. }
  287. // Schedule loop count actor.
  288. if (actor_set->loop_count_actor_ != nullptr) {
  289. auto base_actor = static_cast<ActorReference>(actor_set->loop_count_actor_);
  290. (void)actorMgr->Spawn(base_actor);
  291. }
  292. }
// Prepare everything a run needs: device data for persistent tensors (value nodes and weights),
// the host tensor queue payload for non-weight parameters, and the graph's output tensors.
void GraphScheduler::PrepareRun(const KernelGraphPtr &graph, const DeviceContext *device_context,
                                const std::vector<TensorPtr> *input_tensors, VectorRef *const &outputs) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(outputs);
  // 1.Prepare the data of device tensor store(value nodes of graph).
  for (const auto &value_node : graph->graph_value_nodes()) {
    PrepareDataForValueNode(value_node, device_context);
  }
  // 2.Prepare the data of device tensor store(weights of graph), and fill the host tensors for non weighted parameters.
  // NOTE(review): assumes input_tensors has one entry per input node, in the same order — confirm with callers.
  std::vector<TensorPtr> host_tensors;
  const auto &input_nodes = graph->input_nodes();
  for (size_t i = 0; i < input_nodes.size(); ++i) {
    const auto &input_node = input_nodes[i];
    const auto &input_tensor = (*input_tensors)[i];
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsPersistentDeviceTensor(input_node)) {
      // Prepare the device data for weights.
      PrepareDataForWeightNode(input_node, input_tensor, device_context);
    } else {
      // Fill the host tensors for non weighted parameters.
      host_tensors.emplace_back(input_tensor);
    }
  }
  // 3.Prepare the data of host tensor queue(non weighted parameters of graph).
  const auto &host_tensor_queue = FetchHostQueue(graph);
  MS_EXCEPTION_IF_NULL(host_tensor_queue);
  host_tensor_queue->PushData(host_tensors);
  // 4.Prepare the output tensor of graph.
  for (const auto &output_node : graph->outputs()) {
    MS_EXCEPTION_IF_NULL(output_node);
    MS_LOG(INFO) << "Create node output: " << output_node->fullname_with_scope();
    outputs->emplace_back(CreateOutputTensors(output_node, graph, *input_tensors));
  }
}
  329. bool GraphScheduler::Run(const ActorSet *actor_set, GraphExecutionStrategy strategy) {
  330. MS_EXCEPTION_IF_NULL(actor_set);
  331. // Construct OpContext.
  332. OpContext<DeviceTensor> op_context;
  333. auto sequential_num = uuids::RandomBasedGenerator::GenerateRandomUuid();
  334. op_context.sequential_num_ = &sequential_num;
  335. Promise<int> result;
  336. op_context.results_->push_back(result);
  337. // Trigger no input kernel actor running.
  338. for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  339. MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
  340. Async(no_input_kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context);
  341. }
  342. // Trigger data source actor running.
  343. for (auto &data_source_actor : actor_set->data_source_actors_) {
  344. MS_EXCEPTION_IF_NULL(data_source_actor);
  345. Async(data_source_actor->GetAID(), &DataSourceActor::FetchData, &op_context);
  346. }
  347. // Trigger kernel actor running in the step execution strategy.
  348. if (strategy == GraphExecutionStrategy::kStep) {
  349. for (auto &kernel_actor : actor_set->kernel_actors_) {
  350. MS_EXCEPTION_IF_NULL(kernel_actor);
  351. Async(kernel_actor->GetAID(), &KernelActor::RunOpControl, nullptr, &op_context);
  352. }
  353. }
  354. // Get the run result.
  355. auto result_future = result.GetFuture();
  356. result_future.Wait();
  357. if (!result_future.IsOK()) {
  358. return false;
  359. }
  360. return true;
  361. }
  362. ActorSet *GraphScheduler::Fetch(const KernelGraphPtr &graph) const {
  363. MS_EXCEPTION_IF_NULL(graph);
  364. auto iter = graph_to_actors_.find(graph);
  365. if (iter != graph_to_actors_.end()) {
  366. return iter->second.get();
  367. } else {
  368. MS_LOG(ERROR) << "Can't find the actors map of graph: " << graph->ToString();
  369. return nullptr;
  370. }
  371. }
  372. ActorSetPtr GraphScheduler::Build(const KernelGraphPtr &graph, const DeviceContext *device_context) {
  373. auto actor_set = std::make_shared<ActorSet>();
  374. MS_EXCEPTION_IF_NULL(actor_set);
  375. auto data_source_actors = BuildDataSourceActor(graph, device_context);
  376. actor_set->data_source_actors_.swap(data_source_actors);
  377. auto kernel_actors = BuildKernelActor(graph, device_context);
  378. actor_set->kernel_actors_.swap(kernel_actors);
  379. auto loop_count_actor = BuildLoopCountActor(graph);
  380. actor_set->loop_count_actor_ = loop_count_actor;
  381. return actor_set;
  382. }
// Link all actors of the set following the graph's execution order: data arrows from data
// source / kernel actors into each kernel actor, and control arrows toward the loop count actor.
void GraphScheduler::Link(ActorSet *actor_set, const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
  MS_EXCEPTION_IF_NULL(actor_set);
  MS_EXCEPTION_IF_NULL(graph);
  // Index the kernel actors by name for fast lookup during linking.
  std::unordered_map<std::string, KernelActorPtr> kernel_actors_temp_map;
  for (auto &actor : actor_set->kernel_actors_) {
    MS_EXCEPTION_IF_NULL(actor);
    kernel_actors_temp_map.emplace(actor->GetAID().Name(), actor);
  }
  // Foreach the execution order to link the actors.
  auto execution_order = graph->execution_order();
  for (auto &kernel : execution_order) {
    if (!IsKernelActor(kernel)) {
      continue;
    }
    auto kernel_actor = FindKernelActor(kernel_actors_temp_map, kernel->fullname_with_scope());
    // Link the control arrows of kernel actor.
    LinkControlArrowForKernelActor(kernel_actor, actor_set->loop_count_actor_.get(), graph, strategy);
    // Link one data arrow per input, chosen by what produces that input.
    for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
      KernelWithIndex from_kernel_with_output_idx = AnfAlgo::GetPrevNodeOutput(kernel, i);
      KernelWithIndex to_kernel_with_input_idx = std::make_pair(kernel, i);
      auto from_kernel = from_kernel_with_output_idx.first;
      if (IsDeviceQueueDSActor(from_kernel)) {
        // Link the data arrows of device queue data source actor.
        auto from_actor = FindDeviceQueueDSActor(actor_set->data_source_actors_);
        LinkDataArrowForDeviceDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
      } else if (IsHostQueueDSActor(from_kernel)) {
        // Link the data arrows of host queue data source actor.
        auto from_actor = FindHostQueueDSActor(actor_set->data_source_actors_);
        LinkDataArrowForHostDSActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
      } else {
        // Link the data arrows of kernel actor.
        auto from_actor = FindKernelActor(kernel_actors_temp_map, from_kernel->fullname_with_scope());
        LinkDataArrowForKernelActor(from_actor, kernel_actor, from_kernel_with_output_idx, to_kernel_with_input_idx);
      }
    }
  }
  // BuildNoInputKernelActor depends on whether kernel actors have input, so must be behind the link of kernel actors.
  auto no_input_kernel_actors = BuildNoInputKernelActor(graph);
  actor_set->no_input_kernel_actors_.swap(no_input_kernel_actors);
  // Link the control arrows of loop count actor, which depends on the no input kernel actors.
  LinkControlArrowForLoopCountActor(actor_set->loop_count_actor_.get(), graph);
}
// Build the data source actors of the graph: at most one host queue data source actor that
// feeds all non-weight parameters, and (when a GetNext kernel exists in the execution order)
// one device queue data source actor.
std::vector<DataSourceActorPtr> GraphScheduler::BuildDataSourceActor(const KernelGraphPtr &graph,
                                                                     const DeviceContext *device_context) {
  MS_EXCEPTION_IF_NULL(graph);
  std::vector<DataSourceActorPtr> data_source_actors;
  // Build host queue data source actor.
  HostQueueDSActorPtr host_queue_ds_actor = nullptr;
  for (auto &input_node : graph->input_nodes()) {
    MS_EXCEPTION_IF_NULL(input_node);
    if (IsHostQueueDSActor(input_node)) {
      // Lazily create the single host queue actor on the first non-weight parameter.
      if (host_queue_ds_actor == nullptr) {
        auto actor_name = graph->ToString() + "_" + "HostQueueDataSourceActor";
        MS_LOG(INFO) << "Create host queue data source actor: " << actor_name;
        auto host_queue = std::make_shared<HostTensorQueue>();
        graph_to_host_queue_.emplace(graph, host_queue);
        host_queue_ds_actor =
          std::make_shared<HostQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_, host_queue);
        data_source_actors.emplace_back(host_queue_ds_actor);
      }
      // Every non-weight parameter becomes a data node of that actor, in input order.
      host_queue_ds_actor->data_nodes_.emplace_back(input_node);
    }
  }
  // Build device queue data source actor, driven by the first GetNext kernel found.
  auto execution_order = graph->execution_order();
  auto iter = std::find_if(execution_order.begin(), execution_order.end(),
                           [](const CNodePtr &node) { return IsDeviceQueueDSActor(node); });
  if (iter != execution_order.end()) {
    auto actor_name = graph->ToString() + "_" + "DeviceQueueDataSourceActor";
    MS_LOG(INFO) << "Create queue data source actor: " << actor_name;
    auto device_queue_ds_actor =
      std::make_shared<DeviceQueueDataSourceActor>(actor_name, 1, device_context, memory_manager_aid_);
    MS_EXCEPTION_IF_NULL(device_queue_ds_actor);
    data_source_actors.emplace_back(device_queue_ds_actor);
    device_queue_ds_actor->data_kernel_ = *iter;
  }
  return data_source_actors;
}
  461. std::vector<KernelActorPtr> GraphScheduler::BuildKernelActor(const KernelGraphPtr &graph,
  462. const DeviceContext *device_context) {
  463. MS_EXCEPTION_IF_NULL(graph);
  464. std::vector<KernelActorPtr> kernel_actors;
  465. auto execution_order = graph->execution_order();
  466. for (auto &kernel : execution_order) {
  467. if (IsKernelActor(kernel)) {
  468. auto kernel_actor =
  469. std::make_shared<KernelActor>(kernel->fullname_with_scope(), kernel, device_context, memory_manager_aid_);
  470. MS_EXCEPTION_IF_NULL(kernel_actor);
  471. kernel_actors.emplace_back(kernel_actor);
  472. }
  473. }
  474. return kernel_actors;
  475. }
  476. std::vector<KernelActorPtr> GraphScheduler::BuildNoInputKernelActor(const KernelGraphPtr &graph) {
  477. MS_EXCEPTION_IF_NULL(graph);
  478. std::vector<KernelActorPtr> no_input_kernel_actors;
  479. auto actor_set = Fetch(graph);
  480. MS_EXCEPTION_IF_NULL(actor_set);
  481. for (auto &kernel_actor : actor_set->kernel_actors_) {
  482. MS_EXCEPTION_IF_NULL(kernel_actor);
  483. if ((kernel_actor->input_datas_num_ == 0) && (kernel_actor->input_controls_num_ == 0)) {
  484. no_input_kernel_actors.emplace_back(kernel_actor);
  485. }
  486. }
  487. return no_input_kernel_actors;
  488. }
  489. LoopCountActorPtr GraphScheduler::BuildLoopCountActor(const KernelGraphPtr &graph) {
  490. MS_EXCEPTION_IF_NULL(graph);
  491. auto loop_count = ConfigManager::GetInstance().iter_num();
  492. auto actor_name = graph->ToString() + "_" + "LoopCountActor";
  493. auto loop_count_actor = std::make_shared<LoopCountActor>(actor_name, loop_count);
  494. MS_EXCEPTION_IF_NULL(loop_count_actor);
  495. return loop_count_actor;
  496. }
  497. void GraphScheduler::LinkDataArrowForDeviceDSActor(DeviceQueueDataSourceActor *from_actor, KernelActor *to_actor,
  498. KernelWithIndex from_kernel_with_output_idx,
  499. KernelWithIndex to_kernel_with_input_idx) {
  500. MS_EXCEPTION_IF_NULL(from_actor);
  501. MS_EXCEPTION_IF_NULL(to_actor);
  502. auto from_kernel = from_kernel_with_output_idx.first;
  503. MS_EXCEPTION_IF_NULL(from_kernel);
  504. auto from_output_index = from_kernel_with_output_idx.second;
  505. auto to_input_index = to_kernel_with_input_idx.second;
  506. auto to_aid = to_actor->GetAID();
  507. auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
  508. from_actor->output_op_arrows_.emplace_back(op_arrow);
  509. to_actor->input_datas_num_++;
  510. // Update the reference count of device tensor.
  511. UpdateRefCount(from_kernel, from_output_index);
  512. }
  513. void GraphScheduler::LinkDataArrowForHostDSActor(HostQueueDataSourceActor *from_actor, KernelActor *to_actor,
  514. KernelWithIndex from_kernel_with_output_idx,
  515. KernelWithIndex to_kernel_with_input_idx) {
  516. MS_EXCEPTION_IF_NULL(from_actor);
  517. MS_EXCEPTION_IF_NULL(to_actor);
  518. auto from_kernel = from_kernel_with_output_idx.first;
  519. MS_EXCEPTION_IF_NULL(from_kernel);
  520. auto from_output_index = from_kernel_with_output_idx.second;
  521. auto to_input_index = to_kernel_with_input_idx.second;
  522. auto data_nodes = from_actor->data_nodes_;
  523. auto iter = find(data_nodes.begin(), data_nodes.end(), from_kernel);
  524. if (iter == data_nodes.end()) {
  525. MS_LOG(EXCEPTION) << "Parameter node: " << from_kernel->fullname_with_scope() << " is not exist.";
  526. }
  527. auto position = IntToSize(std::distance(data_nodes.begin(), iter));
  528. auto to_aid = to_actor->GetAID();
  529. auto op_arrow = std::make_shared<OpArrow>(position, to_aid, to_input_index);
  530. from_actor->output_op_arrows_.emplace_back(op_arrow);
  531. to_actor->input_datas_num_++;
  532. // Update the reference count of device tensor.
  533. UpdateRefCount(from_kernel, from_output_index);
  534. }
  535. void GraphScheduler::LinkDataArrowForKernelActor(KernelActor *from_actor, KernelActor *to_actor,
  536. KernelWithIndex from_kernel_with_output_idx,
  537. KernelWithIndex to_kernel_with_input_idx) {
  538. MS_EXCEPTION_IF_NULL(to_actor);
  539. auto from_kernel = from_kernel_with_output_idx.first;
  540. MS_EXCEPTION_IF_NULL(from_kernel);
  541. auto from_output_index = from_kernel_with_output_idx.second;
  542. auto to_input_index = to_kernel_with_input_idx.second;
  543. if (IsPersistentDeviceTensor(from_kernel)) {
  544. to_actor->device_tensor_store_keys_.emplace_back(to_input_index, static_cast<void *>(from_kernel.get()));
  545. } else if (IsKernelActor(from_kernel)) {
  546. MS_EXCEPTION_IF_NULL(from_actor);
  547. auto to_aid = to_actor->GetAID();
  548. auto op_arrow = std::make_shared<OpArrow>(from_output_index, to_aid, to_input_index);
  549. from_actor->output_op_arrows_.emplace_back(op_arrow);
  550. to_actor->input_datas_num_++;
  551. // Update the reference count of device tensor.
  552. UpdateRefCount(from_kernel, from_output_index);
  553. }
  554. }
  555. void GraphScheduler::LinkControlArrowForKernelActor(KernelActor *from_actor, LoopCountActor *to_actor,
  556. const KernelGraphPtr &graph, GraphExecutionStrategy strategy) {
  557. MS_EXCEPTION_IF_NULL(from_actor);
  558. MS_EXCEPTION_IF_NULL(to_actor);
  559. MS_EXCEPTION_IF_NULL(graph);
  560. if (strategy == GraphExecutionStrategy::kStep) {
  561. from_actor->input_controls_num_++;
  562. }
  563. if (opt::IsNotRealUsedByOthers(graph, from_actor->kernel_)) {
  564. auto to_aid = to_actor->GetAID();
  565. from_actor->output_op_controls_.emplace_back(to_aid);
  566. to_actor->input_controls_num_++;
  567. }
  568. }
  569. void GraphScheduler::LinkControlArrowForLoopCountActor(LoopCountActor *loop_count_actor, const KernelGraphPtr &graph) {
  570. MS_EXCEPTION_IF_NULL(graph);
  571. MS_EXCEPTION_IF_NULL(loop_count_actor);
  572. auto actor_set = Fetch(graph);
  573. MS_EXCEPTION_IF_NULL(actor_set);
  574. // Set the source data actor.
  575. for (auto &data_source_actor : actor_set->data_source_actors_) {
  576. MS_EXCEPTION_IF_NULL(data_source_actor);
  577. loop_count_actor->data_source_aids_.emplace_back(data_source_actor->GetAID());
  578. }
  579. // Set the no input kernel actor.
  580. for (auto &no_input_kernel_actor : actor_set->no_input_kernel_actors_) {
  581. MS_EXCEPTION_IF_NULL(no_input_kernel_actor);
  582. loop_count_actor->no_input_kernel_aids_.emplace_back(no_input_kernel_actor->GetAID());
  583. }
  584. }
  585. void GraphScheduler::PersistDeviceTensor(const KernelGraphPtr &graph) {
  586. MS_EXCEPTION_IF_NULL(graph);
  587. for (auto &value_node : graph->graph_value_nodes()) {
  588. MS_EXCEPTION_IF_NULL(value_node);
  589. auto device_tensor = AnfAlgo::GetMutableOutputAddr(value_node, 0);
  590. DeviceTensorStore::GetInstance().Insert(value_node.get(), device_tensor);
  591. device_tensor->set_ref_count(SIZE_MAX);
  592. device_tensor->ResetRefCountUsed();
  593. }
  594. for (auto &input_node : graph->input_nodes()) {
  595. MS_EXCEPTION_IF_NULL(input_node);
  596. if (IsPersistentDeviceTensor(input_node)) {
  597. auto device_tensor = AnfAlgo::GetMutableOutputAddr(input_node, 0);
  598. DeviceTensorStore::GetInstance().Insert(input_node.get(), device_tensor);
  599. device_tensor->set_ref_count(SIZE_MAX);
  600. device_tensor->ResetRefCountUsed();
  601. }
  602. }
  603. }
  604. HostTensorQueue *GraphScheduler::FetchHostQueue(const KernelGraphPtr &graph) const {
  605. MS_EXCEPTION_IF_NULL(graph);
  606. const auto &iter = graph_to_host_queue_.find(graph);
  607. if (iter != graph_to_host_queue_.end()) {
  608. return iter->second.get();
  609. } else {
  610. MS_LOG(ERROR) << "Can't find the host tensor queue map of graph: " << graph->ToString();
  611. return nullptr;
  612. }
  613. }
  614. } // namespace runtime
  615. } // namespace mindspore