You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

launch_kernel.cc 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "runtime/device/launch_kernel.h"
  17. #include <vector>
  18. #include <memory>
  19. namespace mindspore::device {
  20. std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelAddress(const std::vector<size_t> &list,
  21. std::vector<uint8_t *> *addr) {
  22. std::vector<kernel::AddressPtr> kernel_address;
  23. for (size_t i = 0; i < list.size(); ++i) {
  24. auto size = AlignSizeForLaunchKernel(list[i]);
  25. (*addr)[i] = AllocDeviceMem(size);
  26. auto address = std::make_shared<kernel::Address>();
  27. MS_EXCEPTION_IF_NULL(address);
  28. address->addr = (*addr)[i];
  29. MS_EXCEPTION_IF_NULL(address->addr);
  30. address->size = size;
  31. kernel_address.push_back(address);
  32. }
  33. return kernel_address;
  34. }
  35. std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelInputs(const std::vector<size_t> &inputs_list,
  36. const std::vector<uint8_t *> &inputs_addr) {
  37. std::vector<kernel::AddressPtr> kernel_inputs;
  38. if (inputs_list.size() != inputs_addr.size()) {
  39. MS_LOG(ERROR) << "input_list size should equal to input_addr_ size";
  40. }
  41. for (size_t i = 0; i < inputs_list.size(); ++i) {
  42. auto input_size = AlignSizeForLaunchKernel(inputs_list[i]);
  43. auto input = std::make_shared<kernel::Address>();
  44. MS_EXCEPTION_IF_NULL(input);
  45. input->addr = inputs_addr[i];
  46. MS_EXCEPTION_IF_NULL(input->addr);
  47. input->size = input_size;
  48. kernel_inputs.push_back(input);
  49. }
  50. return kernel_inputs;
  51. }
  52. std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelOutputs(const std::vector<size_t> &outputs_list) {
  53. // init output_addr_
  54. outputs_addr_ = std::vector<uint8_t *>(outputs_list.size(), nullptr);
  55. auto kernel_outputs = ObtainKernelAddress(outputs_list, &outputs_addr_);
  56. return kernel_outputs;
  57. }
  58. std::vector<kernel::AddressPtr> LaunchKernel::ObtainKernelWorkspaces(const std::vector<size_t> &workspaces_list) {
  59. std::vector<kernel::AddressPtr> kernel_workspace;
  60. if (workspaces_list.empty()) {
  61. return kernel_workspace;
  62. }
  63. // init workspace_addr_
  64. workspaces_addr_ = std::vector<uint8_t *>(workspaces_list.size(), nullptr);
  65. kernel_workspace = ObtainKernelAddress(workspaces_list, &workspaces_addr_);
  66. return kernel_workspace;
  67. }
  68. void LaunchKernel::LaunchSingleKernel(const std::vector<uint8_t *> &inputs_addr) {
  69. MS_EXCEPTION_IF_NULL(kernel_mod_);
  70. // obtain kernel inputs
  71. auto kernel_inputs = ObtainKernelInputs(kernel_mod_->GetInputSizeList(), inputs_addr);
  72. // obtain kernel outputs
  73. auto kernel_outputs = ObtainKernelOutputs(kernel_mod_->GetOutputSizeList());
  74. // obtain kernel workspace
  75. auto kernel_workspaces = ObtainKernelWorkspaces(kernel_mod_->GetWorkspaceSizeList());
  76. // launch
  77. auto ret_status = kernel_mod_->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
  78. if (!ret_status) {
  79. MS_LOG(ERROR) << "Launch mul kernel failed.";
  80. }
  81. }
  82. void LaunchKernel::FreeOutputAndWorkspaceDeviceMem() {
  83. // free outputs_addr and workspaces_addr_
  84. for (size_t i = 0; i < outputs_addr_.size(); ++i) {
  85. if (outputs_addr_[i] != nullptr) {
  86. FreeDeviceMem(outputs_addr_[i]);
  87. outputs_addr_[i] = nullptr;
  88. }
  89. }
  90. for (size_t i = 0; i < workspaces_addr_.size(); ++i) {
  91. if (workspaces_addr_[i] != nullptr) {
  92. FreeDeviceMem(workspaces_addr_[i]);
  93. workspaces_addr_[i] = nullptr;
  94. }
  95. }
  96. outputs_addr_.clear();
  97. workspaces_addr_.clear();
  98. }
  99. } // namespace mindspore::device