You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tbe_kernel_mod.cc 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/tbe/tbe_kernel_mod.h"
  17. #include <algorithm>
  18. #include "runtime/rt.h"
  19. #include "nlohmann/json.hpp"
  20. #include "graphengine/inc/framework/ge_runtime/task_info.h"
  21. namespace mindspore {
  22. namespace kernel {
  23. using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
  24. using tbe::KernelManager;
  25. bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inputs,
  26. const std::vector<mindspore::kernel::AddressPtr> &workspace,
  27. const std::vector<mindspore::kernel::AddressPtr> &outputs, void *stream_ptr) {
  28. if (stream_ptr == nullptr) {
  29. MS_LOG(ERROR) << "stream_ptr should not be nullptr.";
  30. return false;
  31. }
  32. if (kernel_pack_ == nullptr) {
  33. MS_LOG(ERROR) << "kernel pack should not be nullptr.";
  34. return false;
  35. }
  36. uint32_t blockdim = 1; // default blockdim equal to 1.
  37. auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &blockdim);
  38. if (func_stub == 0) {
  39. MS_LOG(ERROR) << "GenFuncStub failed.";
  40. return false;
  41. }
  42. // pack all addresses into a vector.
  43. std::vector<void *> runtimeargs;
  44. (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs),
  45. [](const AddressPtr &input) -> void * { return input->addr; });
  46. (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs),
  47. [](const AddressPtr &output) -> void * { return output->addr; });
  48. if (!workspace.empty()) {
  49. (void)std::transform(std::begin(workspace), std::end(workspace), std::back_inserter(runtimeargs),
  50. [](const AddressPtr &addr) -> void * { return addr->addr; });
  51. }
  52. rtL2Ctrl_t *l2ctrl = nullptr;
  53. const void *stubFunc = reinterpret_cast<void *>(func_stub);
  54. auto argsSize = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtimeargs.size());
  55. if (RT_ERROR_NONE != rtKernelLaunch(stubFunc, blockdim, runtimeargs.data(), argsSize, l2ctrl, stream_ptr)) {
  56. MS_LOG(ERROR) << "Call runtime rtKernelLaunch error.";
  57. return false;
  58. }
  59. return true;
  60. }
  61. std::vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &inputs,
  62. const std::vector<AddressPtr> &workspaces,
  63. const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
  64. if (kernel_pack_ == nullptr) {
  65. MS_EXCEPTION(ArgumentError) << "kernel pack should not be nullptr.";
  66. }
  67. std::vector<uint8_t> args;
  68. std::vector<uint8_t> sm_desc;
  69. std::vector<uint8_t> meta_data;
  70. std::vector<void *> input_data_addrs;
  71. std::vector<void *> output_data_addrs;
  72. std::vector<void *> workspace_addrs;
  73. // pack all addresses into a vector.
  74. (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(input_data_addrs),
  75. [](const AddressPtr &input) -> void * { return input->addr; });
  76. (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_data_addrs),
  77. [](const AddressPtr &output) -> void * { return output->addr; });
  78. if (!workspaces.empty()) {
  79. (void)std::transform(std::begin(workspaces), std::end(workspaces), std::back_inserter(workspace_addrs),
  80. [](const AddressPtr &workspace) -> void * { return workspace->addr; });
  81. }
  82. stream_id_ = stream_id;
  83. auto funcstub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim_);
  84. if (funcstub == 0) {
  85. MS_EXCEPTION(ArgumentError) << "GenFuncStub failed.";
  86. }
  87. std::string stub_func = KernelManager::GetStubFuncName(kernel_pack_);
  88. MS_LOG(INFO) << "block_dim is:" << block_dim_;
  89. TbeTaskInfoPtr task_info_ptr =
  90. make_shared<ge::model_runner::TbeTaskInfo>(stream_id, stub_func, block_dim_, args, 0, sm_desc, nullptr, 0,
  91. meta_data, input_data_addrs, output_data_addrs, workspace_addrs);
  92. return {task_info_ptr};
  93. }
  94. vector<size_t> TbeKernelMod::GenParameters() {
  95. auto kernel_json_info = kernel_pack_->kernel_json_info();
  96. return kernel_json_info.parameters;
  97. }
  98. } // namespace kernel
  99. } // namespace mindspore