You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

launch_kernel.cc 4.3 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/device/launch_kernel.h"
  17. namespace mindspore::device {
  18. std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelAddress(const std::vector<size_t> &list,
  19. std::vector<uint8_t *> *addr) {
  20. MS_EXCEPTION_IF_NULL(addr);
  21. std::vector<kernel::AddressPtr> kernel_address;
  22. if (addr->size() < list.size()) {
  23. MS_LOG_EXCEPTION << "Error addr size!";
  24. }
  25. for (size_t i = 0; i < list.size(); ++i) {
  26. auto size = AlignSizeForLaunchKernel(list[i]);
  27. (*addr)[i] = AllocDeviceMem(size);
  28. auto address = std::make_shared<kernel::Address>();
  29. MS_EXCEPTION_IF_NULL(address);
  30. address->addr = (*addr)[i];
  31. MS_EXCEPTION_IF_NULL(address->addr);
  32. address->size = size;
  33. kernel_address.push_back(address);
  34. }
  35. return kernel_address;
  36. }
  37. std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelInputs(const std::vector<size_t> &inputs_list,
  38. const std::vector<uint8_t *> &inputs_addr) {
  39. std::vector<kernel::AddressPtr> kernel_inputs;
  40. if (inputs_list.size() != inputs_addr.size()) {
  41. MS_LOG(ERROR) << "input_list size should equal to input_addr_ size, input_list size: " << inputs_list.size()
  42. << ", input_addr_ size: " << inputs_addr.size();
  43. }
  44. for (size_t i = 0; i < inputs_list.size(); ++i) {
  45. auto input_size = AlignSizeForLaunchKernel(inputs_list[i]);
  46. auto input = std::make_shared<kernel::Address>();
  47. MS_EXCEPTION_IF_NULL(input);
  48. input->addr = inputs_addr[i];
  49. MS_EXCEPTION_IF_NULL(input->addr);
  50. input->size = input_size;
  51. kernel_inputs.push_back(input);
  52. }
  53. return kernel_inputs;
  54. }
  55. std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelOutputs(const std::vector<size_t> &outputs_list) {
  56. // init output_addr_
  57. outputs_addr_ = std::vector<uint8_t *>(outputs_list.size(), nullptr);
  58. auto kernel_outputs = ObtainKernelAddress(outputs_list, &outputs_addr_);
  59. return kernel_outputs;
  60. }
  61. std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelWorkspaces(const std::vector<size_t> &workspaces_list) {
  62. std::vector<kernel::AddressPtr> kernel_workspace;
  63. if (workspaces_list.empty()) {
  64. return kernel_workspace;
  65. }
  66. // init workspace_addr_
  67. workspaces_addr_ = std::vector<uint8_t *>(workspaces_list.size(), nullptr);
  68. kernel_workspace = ObtainKernelAddress(workspaces_list, &workspaces_addr_);
  69. return kernel_workspace;
  70. }
  71. void LaunchKernel::LaunchSingleKernel(const std::vector<uint8_t *> &inputs_addr) {
  72. MS_EXCEPTION_IF_NULL(kernel_mod_);
  73. // obtain kernel inputs
  74. auto kernel_inputs = ObtainKernelInputs(kernel_mod_->GetInputSizeList(), inputs_addr);
  75. // obtain kernel outputs
  76. auto kernel_outputs = ObtainKernelOutputs(kernel_mod_->GetOutputSizeList());
  77. // obtain kernel workspace
  78. auto kernel_workspaces = ObtainKernelWorkspaces(kernel_mod_->GetWorkspaceSizeList());
  79. // launch
  80. auto ret_status = kernel_mod_->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  81. if (!ret_status) {
  82. MS_LOG(ERROR) << "Launch single kernel failed.";
  83. }
  84. }
  85. void LaunchKernel::FreeOutputAndWorkspaceDeviceMem() {
  86. // free outputs_addr and workspaces_addr_
  87. for (size_t i = 0; i < outputs_addr_.size(); ++i) {
  88. if (outputs_addr_[i] != nullptr) {
  89. FreeDeviceMem(outputs_addr_[i]);
  90. outputs_addr_[i] = nullptr;
  91. }
  92. }
  93. for (size_t i = 0; i < workspaces_addr_.size(); ++i) {
  94. if (workspaces_addr_[i] != nullptr) {
  95. FreeDeviceMem(workspaces_addr_[i]);
  96. workspaces_addr_[i] = nullptr;
  97. }
  98. }
  99. outputs_addr_.clear();
  100. workspaces_addr_.clear();
  101. }
  102. } // namespace mindspore::device