You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

gpu_kernel_runtime.cc 24 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/gpu/gpu_kernel_runtime.h"
  17. #include "device/gpu/gpu_device_address.h"
  18. #include "device/gpu/cuda_driver.h"
  19. #include "device/gpu/gpu_buffer_mgr.h"
  20. #include "device/gpu/gpu_device_manager.h"
  21. #include "device/gpu/gpu_memory_allocator.h"
  22. #include "device/gpu/distribution/collective_init.h"
  23. #include "utils/convert_utils.h"
  24. #include "utils/context/ms_context.h"
  25. #include "device/kernel_runtime_manager.h"
  26. #include "device/gpu/gpu_common.h"
  27. #include "common/utils.h"
  28. #include "device/gpu/gpu_memory_manager.h"
  29. #include "kernel/common_utils.h"
  30. #include "device/gpu/gpu_memory_copy_manager.h"
  31. namespace mindspore {
  32. namespace device {
  33. namespace gpu {
  34. using mindspore::device::memswap::MemSwapManager;
  35. using mindspore::device::memswap::SwapKind;
// Blocks until all work queued on this runtime's CUDA stream has completed.
bool GPUKernelRuntime::SyncStream() { return GPUDeviceManager::GetInstance().SyncStream(stream_); }
  37. bool GPUKernelRuntime::Init() {
  38. if (device_init_ == true) {
  39. return true;
  40. }
  41. auto ret = InitDevice();
  42. if (!ret) {
  43. MS_LOG(ERROR) << "InitDevice error.";
  44. return ret;
  45. }
  46. mem_manager_ = std::make_shared<GPUMemoryManager>();
  47. MS_EXCEPTION_IF_NULL(mem_manager_);
  48. mem_manager_->MallocDeviceMemory();
  49. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  50. bool collective_inited = CollectiveInitializer::instance().collective_inited();
  51. if (collective_inited && collective_handle_ != nullptr) {
  52. auto init_nccl_comm_funcptr =
  53. reinterpret_cast<InitNCCLComm>(dlsym(const_cast<void *>(collective_handle_), "InitNCCLComm"));
  54. MS_EXCEPTION_IF_NULL(init_nccl_comm_funcptr);
  55. (*init_nccl_comm_funcptr)();
  56. }
  57. device_init_ = true;
  58. return ret;
  59. }
// Factory for GPU device addresses. device_ptr may be nullptr when the
// actual memory is to be allocated later from the pool.
DeviceAddressPtr GPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
                                                       TypeId type_id) {
  return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id);
}
  64. bool GPUKernelRuntime::InitDevice() {
  65. if (GPUDeviceManager::GetInstance().device_count() <= 0) {
  66. MS_LOG(ERROR) << "No GPU device found.";
  67. return false;
  68. }
  69. const void *collective_handle_ = CollectiveInitializer::instance().collective_handle();
  70. bool collective_inited = CollectiveInitializer::instance().collective_inited();
  71. if (collective_inited && collective_handle_ != nullptr) {
  72. auto get_local_rank_funcptr =
  73. reinterpret_cast<GetLocalRankId>(dlsym(const_cast<void *>(collective_handle_), "local_rank_id"));
  74. MS_EXCEPTION_IF_NULL(get_local_rank_funcptr);
  75. device_id_ = IntToUint((*get_local_rank_funcptr)());
  76. }
  77. if (!GPUDeviceManager::GetInstance().is_device_id_init()) {
  78. if (!GPUDeviceManager::GetInstance().set_cur_device_id(device_id_)) {
  79. MS_LOG(ERROR) << "Failed to set current device to " << SizeToInt(device_id_);
  80. return false;
  81. }
  82. }
  83. GPUDeviceManager::GetInstance().InitDevice();
  84. stream_ = GPUDeviceManager::GetInstance().default_stream();
  85. if (stream_ == nullptr) {
  86. MS_LOG(ERROR) << "No default CUDA stream found.";
  87. return false;
  88. }
  89. return true;
  90. }
// Tears down device-side resources in dependency order: the dataset data
// queue, per-graph swap queues and pinned host buffers, the CUDA device
// state, the pooled device memory, and finally the compiled-kernel cache.
void GPUKernelRuntime::ReleaseDeviceRes() {
  // For dataset mode.
  if (GpuBufferMgr::GetInstance().IsInit()) {
    if (!GpuBufferMgr::GetInstance().IsClosed()) {
      if (!GpuBufferMgr::GetInstance().CloseNotify()) {
        MS_LOG(EXCEPTION) << "Could not close gpu data queue.";
      }
    }
    CHECK_OP_RET_WITH_EXCEPT(GpuBufferMgr::GetInstance().Destroy(), "Could not destroy gpu data queue.");
  }
  // destroy remaining memory swap events and free host memory
  for (auto &item : mem_swap_map_) {
    auto &mem_swap_manager = item.second;
    MS_EXCEPTION_IF_NULL(mem_swap_manager);
    if (mem_swap_manager->trigger_swap()) {
      mem_swap_manager->ClearSwapQueue();
      mem_swap_manager->ReleaseHostPinnedMem();
    }
  }
  GPUDeviceManager::GetInstance().ReleaseDevice();
  if (mem_manager_ != nullptr) {
    mem_manager_->FreeDeviceMemory();
  }
  kernel::KernelMeta::GetInstance()->RemoveKernelCache();
}
// Assigns device memory for the graph: static memory for inputs and value
// nodes always; then either prepares lazy pool allocation (dynamic memory
// pool enabled — real memory is granted at launch time) or assigns kernel
// memory up front via AssignDynamicMemory.
void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  mem_manager_->ResetDynamicMemory();
  AssignStaticMemoryInput(graph);
  AssignStaticMemoryValueNode(graph);
  bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
  if (is_enable_dynamic_mem) {
    // Use the dynamic memory pool: set up reference counts and placeholder
    // output addresses; LaunchKernelDynamic fills in pool memory later.
    InitKernelRefCount(graph);
    InitKernelOutputAddress(graph);
  } else {
    AssignDynamicMemory(graph);
  }
}
// Runs the graph once and logs the wall-clock cost in microseconds. With
// the dynamic memory pool enabled (and not in pynative-infer mode), launch
// failures are retried after advancing the memory-swap strategy; otherwise
// a plain LaunchKernel is used.
bool GPUKernelRuntime::Run(session::KernelGraph *graph) {
  bool ret = true;
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool();
  bool is_enable_pynative_infer = context_ptr->enable_pynative_infer();
  // One swap manager per graph, created lazily on the graph's first run.
  auto iter = mem_swap_map_.find(graph);
  if (iter == mem_swap_map_.end()) {
    GPUMemCopyManagerPtr gpu_mem_copy_manager = std::make_shared<GPUMemCopyManager>();
    iter = mem_swap_map_.emplace(graph, std::make_shared<MemSwapManager>(gpu_mem_copy_manager)).first;
  }
  mem_swap_manager_ = iter->second;
  MS_EXCEPTION_IF_NULL(mem_swap_manager_);
  struct timeval start_time, end_time;
  (void)gettimeofday(&start_time, nullptr);
  if (is_enable_dynamic_mem && !is_enable_pynative_infer) {
    // Retry loop: each failed launch releases kernel output memory, lazily
    // initializes the swap manager, and asks it for a more aggressive swap
    // plan (RetreatSwapInfo); give up when it has none left.
    while (!LaunchKernelDynamic(graph)) {
      ClearKernelOutputAddress(graph);
      if (!mem_swap_manager_->mem_swap_init()) {
        mem_swap_manager_->Init(graph);
      }
      if (!mem_swap_manager_->RetreatSwapInfo()) {
        return false;
      }
    }
  } else {
    ret = LaunchKernel(graph);
  }
  (void)gettimeofday(&end_time, nullptr);
  const uint64_t kUSecondInSecond = 1000000;
  uint64_t cost = kUSecondInSecond * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
  cost += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
  MS_LOG(DEBUG) << "kernel runtime run graph in " << cost << " us";
  return ret;
}
  167. void GPUKernelRuntime::InitKernelRefCount(const session::KernelGraph *graph) {
  168. MS_EXCEPTION_IF_NULL(graph);
  169. MemReuseUtilPtr mem_reuse_util_ptr = std::make_shared<memreuse::MemReuseUtil>();
  170. MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
  171. // Init the kernel reference count.
  172. if (!mem_reuse_util_ptr->InitDynamicKernelRef(graph)) {
  173. MS_LOG(EXCEPTION) << "Init kernel reference count failed";
  174. }
  175. mem_reuse_util_ptr->SetKernelDefMap();
  176. mem_reuse_util_ptr->SetReuseRefCount();
  177. // Can't free the device address of graph output, so set the reference count of graph output specially.
  178. mem_reuse_util_ptr->SetGraphOutputRefCount();
  179. auto graph_id = graph->graph_id();
  180. mem_reuse_util_map_[graph_id] = mem_reuse_util_ptr;
  181. }
  182. void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph) {
  183. MS_EXCEPTION_IF_NULL(graph);
  184. auto &kernels = graph->execution_order();
  185. for (const auto &kernel : kernels) {
  186. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  187. MS_EXCEPTION_IF_NULL(kernel_mod);
  188. auto output_sizes = kernel_mod->GetOutputSizeList();
  189. for (size_t i = 0; i < output_sizes.size(); ++i) {
  190. if (AnfAlgo::OutputAddrExist(kernel, i)) {
  191. continue;
  192. }
  193. std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
  194. auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
  195. auto device_address = CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
  196. AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
  197. }
  198. }
  199. }
  200. void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *graph) {
  201. MS_EXCEPTION_IF_NULL(graph);
  202. auto &kernels = graph->execution_order();
  203. for (const auto &kernel : kernels) {
  204. auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  205. MS_EXCEPTION_IF_NULL(kernel_mod);
  206. auto output_sizes = kernel_mod->GetOutputSizeList();
  207. for (size_t i = 0; i < output_sizes.size(); ++i) {
  208. if (!AnfAlgo::OutputAddrExist(kernel, i)) {
  209. continue;
  210. }
  211. auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
  212. if (device_address->ptr_) {
  213. mem_manager_->FreeMemFromMemPool(device_address);
  214. }
  215. device_address->set_status(DeviceAddressStatus::kInDevice);
  216. }
  217. }
  218. }
// Launches all kernels of `graph` with pool-allocated (dynamic) memory.
// Returns false when memory for some kernel cannot be obtained even after
// swapping — the caller (Run) then clears outputs, retreats the swap plan,
// and retries.
bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto graph_id = graph->graph_id();
  auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id];
  MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
  // Reset the reference count.
  mem_reuse_util_ptr->ResetDynamicUsedRefCount();
  // The inputs and outputs memory of communication kernel need be continuous, so separate processing.
  AllocCommunicationOpDynamicRes(graph);
  auto &kernels = graph->execution_order();
  for (const auto &kernel : kernels) {
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    AddressPtrList kernel_inputs;
    AddressPtrList kernel_workspaces;
    AddressPtrList kernel_outputs;
    auto ret = AllocKernelDynamicRes(*kernel_mod, kernel, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
    if (!ret) {
      // Allocation failed: let the caller adjust the swap strategy and retry.
      return false;
    }
    if (!kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_)) {
      MS_LOG(EXCEPTION) << "Launch kernel failed.";
    }
    // Give back memory whose dynamic reference count dropped to zero.
    FreeKernelDynamicRes(kernel, kernel_workspaces, graph_id);
    if (mem_swap_manager_->trigger_swap() && mem_swap_manager_->QueryKernelTriggerSwap(kernel)) {
      // Swap trigger point: wait for compute to finish before enqueueing
      // this kernel's swap tasks, so swapped data is consistent.
      CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
      if (!AddMemSwapTask(kernel)) {
        return false;
      }
    }
    if (mem_swap_manager_->trigger_swap()) {
      mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
    }
  }
  CHECK_OP_RET_WITH_EXCEPT(SyncStream(), "SyncStream failed.");
  if (mem_swap_manager_->trigger_swap()) {
    mem_swap_manager_->ClearSwapQueue();
  }
  return true;
}
// Enqueues the swap tasks the swap manager recorded for `kernel`.
// Device-to-host swaps are enqueued unconditionally; host-to-device swaps
// depend on the target address's current status. Returns false only when a
// swap-in needs device memory that cannot be allocated.
bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) {
  auto &mem_swap_info_list = mem_swap_manager_->QueryKernelMemSwapInfo(kernel);
  for (auto &mem_swap_info : mem_swap_info_list) {
    auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_);
    const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_];
    auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_);
    if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
      mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address);
    } else if (mem_swap_info.swap_kind_ == SwapKind::kHostToDevice) {
      auto status = device_address->status();
      if (status == DeviceAddressStatus::kInDeviceToHost) {
        // Needed again while still being swapped out: blacklist the pending
        // swap-out and keep treating the data as resident on device.
        mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
        device_address->set_status(DeviceAddressStatus::kInDevice);
      } else if (status == DeviceAddressStatus::kInHost) {
        // Ensure device memory exists before scheduling the copy back in.
        if (!device_address->ptr_ && !AttemptMallocMem(device_address, device_address->size_)) {
          return false;
        }
        if (!mem_swap_manager_->FindInSwapInBlackList(device_address->ptr_)) {
          mem_swap_manager_->AddMemSwapTask(SwapKind::kHostToDevice, device_address, host_address);
        }
      }
    }
  }
  return true;
}
// Tries to allocate `size` bytes from the pool into `device_address`. On
// failure with swapping enabled, waits for in-flight device-to-host swaps,
// frees the device memory of completed ones (skipping blacklisted
// addresses), and retries the allocation once.
bool GPUKernelRuntime::AttemptMallocMem(const DeviceAddressPtr &device_address, size_t size) {
  auto ret = mem_manager_->MallocMemFromMemPool(device_address, size);
  if (!ret) {
    if (!mem_swap_manager_->trigger_swap()) {
      return false;
    }
    mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
    while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
      if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
        device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
        mem_manager_->FreeMemFromMemPool(device_address_swap_out);
      }
    }
    ret = mem_manager_->MallocMemFromMemPool(device_address, size);
    if (!ret) {
      return false;
    }
  }
  return true;
}
  304. void *GPUKernelRuntime::AttemptMallocMem(size_t size) {
  305. auto device_ptr = mem_manager_->MallocMemFromMemPool(size);
  306. if (!device_ptr) {
  307. if (!mem_swap_manager_->trigger_swap()) {
  308. return nullptr;
  309. }
  310. mem_swap_manager_->SyncMemCopyStream(SwapKind::kDeviceToHost);
  311. while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
  312. if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
  313. device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
  314. mem_manager_->FreeMemFromMemPool(device_address_swap_out);
  315. }
  316. }
  317. device_ptr = mem_manager_->MallocMemFromMemPool(size);
  318. if (!device_ptr) {
  319. return nullptr;
  320. }
  321. }
  322. return device_ptr;
  323. }
  324. bool GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
  325. const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs,
  326. AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs) {
  327. if (!AllocKernelInputDynamicRes(kernel, kernel_inputs)) {
  328. return false;
  329. }
  330. if (!AllocKernelOutputDynamicRes(kernel_mod, kernel, kernel_outputs)) {
  331. return false;
  332. }
  333. if (!AllocKernelWorkspaceDynamicRes(kernel_mod, kernel, kernel_workspaces)) {
  334. return false;
  335. }
  336. return true;
  337. }
  338. bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &kernel, AddressPtrList *kernel_inputs) {
  339. MS_EXCEPTION_IF_NULL(kernel);
  340. MS_EXCEPTION_IF_NULL(kernel_inputs);
  341. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
  342. auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
  343. MS_EXCEPTION_IF_NULL(device_address);
  344. if (mem_swap_manager_->trigger_swap()) {
  345. while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
  346. device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
  347. }
  348. auto status = device_address->status();
  349. switch (status) {
  350. case DeviceAddressStatus::kInDevice:
  351. break;
  352. case DeviceAddressStatus::kInHost:
  353. break;
  354. case DeviceAddressStatus::kInDeviceToHost: {
  355. mem_swap_manager_->InsertSwapInBlackList(device_address->ptr_);
  356. device_address->set_status(DeviceAddressStatus::kInDevice);
  357. break;
  358. }
  359. case DeviceAddressStatus::kInHostToDevice: {
  360. while (device_address->status() != DeviceAddressStatus::kInDevice) {
  361. while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
  362. device_address_swap_in->set_status(DeviceAddressStatus::kInDevice);
  363. }
  364. }
  365. break;
  366. }
  367. default:
  368. MS_LOG(ERROR) << "Invaild device address status";
  369. return false;
  370. }
  371. }
  372. MS_EXCEPTION_IF_NULL(device_address->ptr_);
  373. kernel::AddressPtr input = std::make_shared<kernel::Address>();
  374. MS_EXCEPTION_IF_NULL(input);
  375. input->addr = device_address->ptr_;
  376. input->size = device_address->size_;
  377. kernel_inputs->emplace_back(input);
  378. }
  379. return true;
  380. }
  381. bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
  382. const mindspore::AnfNodePtr &kernel,
  383. AddressPtrList *kernel_outputs) {
  384. MS_EXCEPTION_IF_NULL(kernel);
  385. MS_EXCEPTION_IF_NULL(kernel_outputs);
  386. MS_EXCEPTION_IF_NULL(mem_manager_);
  387. if (mem_swap_manager_->trigger_swap()) {
  388. while (auto device_address_swap_out = mem_swap_manager_->UpdateSwapQueue(SwapKind::kDeviceToHost)) {
  389. if (!mem_swap_manager_->FindInSwapInBlackList(device_address_swap_out->ptr_) && device_address_swap_out->ptr_) {
  390. device_address_swap_out->set_status(DeviceAddressStatus::kInHost);
  391. mem_manager_->FreeMemFromMemPool(device_address_swap_out);
  392. }
  393. }
  394. }
  395. auto output_sizes = kernel_mod.GetOutputSizeList();
  396. for (size_t i = 0; i < output_sizes.size(); ++i) {
  397. auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
  398. MS_EXCEPTION_IF_NULL(device_address);
  399. if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) {
  400. return false;
  401. }
  402. kernel::AddressPtr output = std::make_shared<kernel::Address>();
  403. MS_EXCEPTION_IF_NULL(output);
  404. output->addr = device_address->ptr_;
  405. output->size = output_sizes[i];
  406. kernel_outputs->emplace_back(output);
  407. }
  408. return true;
  409. }
  410. bool GPUKernelRuntime::AllocKernelWorkspaceDynamicRes(const mindspore::kernel::KernelMod &kernel_mod,
  411. const mindspore::AnfNodePtr &kernel,
  412. AddressPtrList *kernel_workspaces) {
  413. MS_EXCEPTION_IF_NULL(kernel);
  414. MS_EXCEPTION_IF_NULL(kernel_workspaces);
  415. MS_EXCEPTION_IF_NULL(mem_manager_);
  416. auto workspace_sizes = kernel_mod.GetWorkspaceSizeList();
  417. for (size_t i = 0; i < workspace_sizes.size(); ++i) {
  418. if (workspace_sizes[i] == 0) {
  419. kernel_workspaces->emplace_back(nullptr);
  420. continue;
  421. }
  422. auto device_ptr = AttemptMallocMem(workspace_sizes[i]);
  423. if (!device_ptr) {
  424. return false;
  425. }
  426. kernel::AddressPtr workspace = std::make_shared<kernel::Address>();
  427. MS_EXCEPTION_IF_NULL(workspace);
  428. workspace->addr = device_ptr;
  429. workspace->size = workspace_sizes[i];
  430. kernel_workspaces->emplace_back(workspace);
  431. }
  432. return true;
  433. }
  434. void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph *graph) {
  435. MS_EXCEPTION_IF_NULL(graph);
  436. auto &kernels = graph->execution_order();
  437. for (auto &kernel : kernels) {
  438. MS_EXCEPTION_IF_NULL(kernel);
  439. if (AnfAlgo::IsCommunicationOp(kernel)) {
  440. AllocCommunicationOpInputDynamicRes(kernel);
  441. AllocCommunicationOpOutputDynamicRes(kernel);
  442. }
  443. }
  444. }
  445. void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) {
  446. MS_EXCEPTION_IF_NULL(kernel);
  447. MS_EXCEPTION_IF_NULL(mem_manager_);
  448. bool is_need_alloc_memory = false;
  449. bool is_need_free_memory = false;
  450. size_t total_size = 0;
  451. std::vector<size_t> size_list;
  452. DeviceAddressPtrList addr_list;
  453. for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
  454. auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
  455. MS_EXCEPTION_IF_NULL(device_address);
  456. if (device_address->ptr_ == nullptr) {
  457. is_need_alloc_memory = true;
  458. } else {
  459. is_need_free_memory = true;
  460. }
  461. total_size += device_address->size_;
  462. size_list.emplace_back(device_address->size_);
  463. addr_list.emplace_back(device_address);
  464. }
  465. AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list);
  466. }
// Gathers the device addresses and sizes of a communication kernel's
// outputs and allocates them as one contiguous block via
// AllocCommunicationOpMemory.
void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  bool is_need_alloc_memory = false;
  bool is_need_free_memory = false;
  size_t total_size = 0;
  std::vector<size_t> size_list;
  DeviceAddressPtrList addr_list;
  auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  auto output_sizes = kernel_mod->GetOutputSizeList();
  for (size_t i = 0; i < output_sizes.size(); ++i) {
    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
    MS_EXCEPTION_IF_NULL(device_address);
    // Any missing allocation forces a contiguous (re)allocation; any
    // existing one must be freed first so the block can be laid out whole.
    if (device_address->ptr_ == nullptr) {
      is_need_alloc_memory = true;
    } else {
      is_need_free_memory = true;
    }
    total_size += output_sizes[i];
    size_list.emplace_back(output_sizes[i]);
    addr_list.emplace_back(device_address);
  }
  AllocCommunicationOpMemory(is_need_alloc_memory, is_need_free_memory, addr_list, total_size, size_list);
}
// Allocates one contiguous device block covering every address in
// `addr_list` (total_size bytes, partitioned per size_list). No-op when no
// member needs allocating; otherwise any already-allocated members are
// freed first so the whole list can be re-allocated contiguously.
// NOTE(review): addr_list and size_list are taken by value, copying both
// containers per call; const& would avoid the copies but requires changing
// the declaration in the header — confirm before touching the signature.
void GPUKernelRuntime::AllocCommunicationOpMemory(bool is_need_alloc_memory, bool is_need_free_memory,
                                                  const DeviceAddressPtrList addr_list, size_t total_size,
                                                  std::vector<size_t> size_list) {
  if (!is_need_alloc_memory) {
    return;
  }
  if (is_need_free_memory) {
    for (const auto &iter : addr_list) {
      MS_EXCEPTION_IF_NULL(iter);
      // Free the inputs/outputs of communication kernel which are not released.
      if (iter->ptr_ != nullptr) {
        mem_manager_->FreeMemFromMemPool(iter);
      }
    }
  }
  auto ret = mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list);
  if (!ret) {
    MS_LOG(EXCEPTION) << "Malloc device memory failed.";
  }
}
// Releases pool memory owned by `kernel` right after its launch: input
// buffers whose dynamic reference count drops to zero, output buffers with
// no remaining references, and all of its workspaces. AllReduce kernels are
// skipped entirely (their contiguous buffers are handled separately).
void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
                                            const AddressPtrList &kernel_workspaces, uint32_t graph_id) {
  MS_EXCEPTION_IF_NULL(kernel);
  MS_EXCEPTION_IF_NULL(mem_manager_);
  auto mem_reuse_util_ptr = mem_reuse_util_map_[graph_id];
  MS_EXCEPTION_IF_NULL(mem_reuse_util_ptr);
  auto cnode = kernel->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName) {
    return;
  }
  // Free the input of kernel by reference count.
  for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
    auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetKernelInputRef(cnode, i);
    if (kernel_ref_count_ptr == nullptr) {
      continue;
    }
    kernel_ref_count_ptr->ref_count_dynamic_use_--;
    if (kernel_ref_count_ptr->ref_count_dynamic_use_ < 0) {
      MS_LOG(EXCEPTION) << "Check dynamic reference count failed.";
    }
    if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
      // This kernel was the last consumer: return the producer's output
      // memory to the pool.
      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
      mem_manager_->FreeMemFromMemPool(device_address);
      device_address->set_status(DeviceAddressStatus::kInDevice);
    }
  }
  // Free the output of kernel, if output has no reference.
  for (size_t i = 0; i < AnfAlgo::GetOutputTensorNum(kernel); ++i) {
    auto kernel_ref_count_ptr = mem_reuse_util_ptr->GetRef(cnode, i);
    if (kernel_ref_count_ptr == nullptr) {
      continue;
    }
    if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
      mem_manager_->FreeMemFromMemPool(device_address);
      device_address->set_status(DeviceAddressStatus::kInDevice);
    }
  }
  // Free the workspace of kernel.
  for (size_t i = 0; i < kernel_workspaces.size(); ++i) {
    auto workspace = kernel_workspaces[i];
    if (workspace != nullptr) {
      MS_EXCEPTION_IF_NULL(workspace->addr);
      mem_manager_->FreeMemFromMemPool(workspace->addr);
      workspace->addr = nullptr;
    }
  }
}
  561. } // namespace gpu
  562. } // namespace device
  563. } // namespace mindspore