You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tbe_kernel_mod.cc 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/tbe/tbe_kernel_mod.h"
  17. #include <algorithm>
  18. #include "runtime/rt.h"
  19. #include "nlohmann/json.hpp"
  20. #include "graphengine/inc/framework/ge_runtime/task_info.h"
  21. namespace mindspore {
  22. namespace kernel {
  23. using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
  24. using tbe::KernelManager;
  25. bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inputs,
  26. const std::vector<mindspore::kernel::AddressPtr> &workspace,
  27. const std::vector<mindspore::kernel::AddressPtr> &outputs, uintptr_t stream_ptr) {
  28. if (stream_ptr == 0) {
  29. MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
  30. return false;
  31. }
  32. if (kernel_pack_ == nullptr) {
  33. MS_LOG(ERROR) << "kernel pack should not be nullptr.";
  34. return false;
  35. }
  36. uint32_t blockdim = 1; // default blockdim equal to 1.
  37. auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &blockdim);
  38. if (func_stub == 0) {
  39. MS_LOG(ERROR) << "GenFuncStub failed.";
  40. return false;
  41. }
  42. // pack all addresses into a vector.
  43. std::vector<void *> runtimeargs;
  44. (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs),
  45. [](const AddressPtr &input) -> void * { return input->addr; });
  46. (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs),
  47. [](const AddressPtr &output) -> void * { return output->addr; });
  48. if (!workspace.empty()) {
  49. (void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs),
  50. [](const AddressPtr &addr) -> void * { return addr->addr; });
  51. }
  52. rtL2Ctrl_t *l2ctrl = nullptr;
  53. auto *stream = reinterpret_cast<rtStream_t *>(stream_ptr);
  54. const void *stubFunc = reinterpret_cast<void *>(func_stub);
  55. auto argsSize = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtimeargs.size());
  56. if (RT_ERROR_NONE != rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream)) {
  57. MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
  58. return false;
  59. }
  60. return true;
  61. }
  62. vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &inputs,
  63. const std::vector<AddressPtr> &workspaces,
  64. const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
  65. if (kernel_pack_ == nullptr) {
  66. MS_EXCEPTION(ArgumentError) << "kernel pack should not be nullptr.";
  67. }
  68. std::vector<uint8_t> args;
  69. std::vector<uint8_t> sm_desc;
  70. std::vector<uint8_t> meta_data;
  71. std::vector<void *> input_data_addrs;
  72. std::vector<void *> output_data_addrs;
  73. std::vector<void *> workspace_addrs;
  74. // pack all addresses into a vector.
  75. (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs),
  76. [](const AddressPtr &input) -> void * { return input->addr; });
  77. (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
  78. [](const AddressPtr &output) -> void * { return output->addr; });
  79. if (!workspaces.empty()) {
  80. (void)std::transform(std::begin(workspaces), std::end(workspaces), std::back_inserter(workspace_addrs),
  81. [](const AddressPtr &workspace) -> void * { return workspace->addr; });
  82. }
  83. uint32_t block_dim = 1; // default blockdim equal to 1.
  84. auto funcstub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
  85. if (funcstub == 0) {
  86. MS_EXCEPTION(ArgumentError) << "GenFuncStub failed.";
  87. }
  88. std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_);
  89. MS_LOG(INFO) << "block_dim is:" << block_dim;
  90. TbeTaskInfoPtr task_info_ptr =
  91. make_shared<ge::model_runner::TbeTaskInfo>(stream_id, stub_func, block_dim, args, 0, sm_desc, nullptr, 0, meta_data,
  92. input_data_addrs, output_data_addrs, workspace_addrs);
  93. return {task_info_ptr};
  94. }
  95. vector<size_t> TbeKernelMod::GenParameters() {
  96. auto kernel_json_info = kernel_pack_->kernel_json_info();
  97. return kernel_json_info.parameters;
  98. }
  99. } // namespace kernel
  100. } // namespace mindspore