You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

memcpy_async.cc 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/mng/memcpy_async.h"
  17. #include <memory>
  18. #include <string>
  19. #include "runtime/mem.h"
  20. #include "common/utils.h"
  21. #include "session/anf_runtime_algorithm.h"
  22. #include "common/trans.h"
  23. using ge::model_runner::MemcpyAsyncTaskInfo;
  24. using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
  25. namespace mindspore {
  26. namespace kernel {
  27. MemCpyAsyncKernel::MemCpyAsyncKernel() {}
  28. MemCpyAsyncKernel::~MemCpyAsyncKernel() {}
  29. bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
  30. const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
  31. auto stream = reinterpret_cast<rtStream_t>(stream_ptr);
  32. if (inputs.size() != 1) {
  33. MS_LOG(ERROR) << "inputs size is not one";
  34. return false;
  35. }
  36. if (outputs.size() != 1) {
  37. MS_LOG(ERROR) << "outputs size is not one";
  38. return false;
  39. }
  40. if (inputs[0]->addr == outputs[0]->addr) {
  41. MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async";
  42. return true;
  43. }
  44. rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
  45. RT_MEMCPY_DEVICE_TO_DEVICE, stream);
  46. if (status != RT_ERROR_NONE) {
  47. MS_LOG(ERROR) << "MemCpyAsync op rtMemcpyAsync failed!";
  48. return false;
  49. }
  50. return true;
  51. }
  52. bool MemCpyAsyncKernel::Init(const mindspore::AnfNodePtr &anf_node) {
  53. MS_EXCEPTION_IF_NULL(anf_node);
  54. GetInputOutputDataType(anf_node);
  55. GetInputOutputTotalCount(anf_node);
  56. return true;
  57. }
  58. void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) {
  59. MS_EXCEPTION_IF_NULL(anf_node);
  60. size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
  61. if (input_size != 1) {
  62. MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
  63. }
  64. input_type_id_ = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, 0);
  65. }
  66. void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
  67. MS_EXCEPTION_IF_NULL(anf_node);
  68. size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
  69. if (input_size != 1) {
  70. MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
  71. }
  72. size_t type_size = trans::TypeIdSize(input_type_id_);
  73. std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
  74. size_t total_size = 1;
  75. for (size_t i = 0; i < shape_i.size(); i++) {
  76. total_size = total_size * shape_i[i];
  77. }
  78. total_size *= type_size;
  79. MS_LOG(INFO) << "MemCpyAsync size[" << total_size << "]";
  80. input_size_list_.emplace_back(total_size);
  81. output_size_list_.emplace_back(total_size);
  82. }
  83. std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const vector<mindspore::kernel::AddressPtr> &inputs,
  84. const vector<mindspore::kernel::AddressPtr> & /*workspace*/,
  85. const vector<mindspore::kernel::AddressPtr> &outputs,
  86. uint32_t stream_id) {
  87. if (inputs.size() != 1) {
  88. MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one";
  89. }
  90. if (outputs.size() != 1) {
  91. MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one";
  92. }
  93. std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr = std::make_shared<MemcpyAsyncTaskInfo>(
  94. stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE);
  95. MS_EXCEPTION_IF_NULL(task_info_ptr);
  96. return {task_info_ptr};
  97. }
  98. const std::vector<TypeId> data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32,
  99. kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16,
  100. kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16,
  101. kNumberTypeFloat32, kNumberTypeFloat64};
  102. const std::vector<std::string> format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
  103. kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0,
  104. kOpFormat_C1HWNCoC0};
  105. MemCpyAsyncDesc::MemCpyAsyncDesc() {}
  106. MemCpyAsyncDesc::~MemCpyAsyncDesc() {}
  107. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> MemCpyAsyncDesc::GetKernelInfo() {
  108. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> memcpy_build_info{};
  109. for (const auto &format : format_list) {
  110. for (const auto &type : data_type_list) {
  111. auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
  112. vector<string> input_format{format};
  113. vector<TypeId> input_type{type};
  114. vector<string> output_format{format};
  115. vector<TypeId> output_type{type};
  116. builder.SetInputsFormat(input_format);
  117. builder.SetInputsDeviceType(input_type);
  118. builder.SetOutputsFormat(output_format);
  119. builder.SetOutputsDeviceType(output_type);
  120. builder.SetProcessor(AICORE);
  121. builder.SetKernelType(RT_KERNEL);
  122. builder.SetFusionType(OPAQUE);
  123. memcpy_build_info.emplace_back(builder.Build());
  124. }
  125. }
  126. return memcpy_build_info;
  127. }
  128. } // namespace kernel
  129. } // namespace mindspore