You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

memcpy_async.cc 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/rts/memcpy_async.h"
  17. #include <memory>
  18. #include <string>
  19. #include "runtime/mem.h"
  20. #include "common/utils.h"
  21. #include "session/anf_runtime_algorithm.h"
  22. #include "common/trans.h"
  23. using ge::model_runner::MemcpyAsyncTaskInfo;
  24. using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
  25. namespace mindspore {
  26. namespace kernel {
  27. MemCpyAsyncKernel::MemCpyAsyncKernel() {}
  28. MemCpyAsyncKernel::~MemCpyAsyncKernel() {}
  29. bool MemCpyAsyncKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
  30. const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  31. if (inputs.size() != 1) {
  32. MS_LOG(ERROR) << "inputs size is not one";
  33. return false;
  34. }
  35. if (outputs.size() != 1) {
  36. MS_LOG(ERROR) << "outputs size is not one";
  37. return false;
  38. }
  39. if (inputs[0]->addr == outputs[0]->addr) {
  40. MS_LOG(INFO) << "input addr is same with output addr , no need exe memcpy async";
  41. return true;
  42. }
  43. if (outputs[0]->size < inputs[0]->size) {
  44. MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
  45. }
  46. // input x -> memcpy_async -> AllReduce
  47. if (outputs[0]->size > inputs[0]->size) {
  48. MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
  49. }
  50. rtError_t status = rtMemcpyAsync(outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size,
  51. RT_MEMCPY_DEVICE_TO_DEVICE, stream_ptr);
  52. if (status != RT_ERROR_NONE) {
  53. MS_LOG(ERROR) << "MemCpyAsync op rtMemcpyAsync failed!";
  54. return false;
  55. }
  56. return true;
  57. }
  58. bool MemCpyAsyncKernel::Init(const mindspore::AnfNodePtr &anf_node) {
  59. MS_EXCEPTION_IF_NULL(anf_node);
  60. GetInputOutputDataType(anf_node);
  61. GetInputOutputTotalCount(anf_node);
  62. return true;
  63. }
  64. void MemCpyAsyncKernel::GetInputOutputDataType(const AnfNodePtr &anf_node) {
  65. MS_EXCEPTION_IF_NULL(anf_node);
  66. size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
  67. if (input_size != 1) {
  68. MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
  69. }
  70. input_type_id_ = AnfAlgo::GetPrevNodeOutputDeviceDataType(anf_node, 0);
  71. }
  72. void MemCpyAsyncKernel::GetInputOutputTotalCount(const AnfNodePtr &anf_node) {
  73. MS_EXCEPTION_IF_NULL(anf_node);
  74. size_t input_size = AnfAlgo::GetInputTensorNum(anf_node);
  75. if (input_size != 1) {
  76. MS_LOG(EXCEPTION) << "MemCpyAsync input size is not 1";
  77. }
  78. size_t type_size = trans::TypeIdSize(input_type_id_);
  79. std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, 0);
  80. size_t total_size = 1;
  81. for (size_t i = 0; i < shape_i.size(); i++) {
  82. total_size = total_size * shape_i[i];
  83. }
  84. total_size *= type_size;
  85. MS_LOG(INFO) << "MemCpyAsync size[" << total_size << "]";
  86. input_size_list_.emplace_back(total_size);
  87. output_size_list_.emplace_back(total_size);
  88. }
  89. std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr> &inputs,
  90. const std::vector<AddressPtr> &,
  91. const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
  92. if (inputs.size() != 1) {
  93. MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one";
  94. }
  95. if (outputs.size() != 1) {
  96. MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one";
  97. }
  98. if (outputs[0]->size < inputs[0]->size) {
  99. MS_LOG(EXCEPTION) << "rtMemcpyAsync destMax < src size";
  100. }
  101. // input x -> memcpy_async -> AllReduce
  102. if (outputs[0]->size > inputs[0]->size) {
  103. MS_LOG(WARNING) << "rtMemcpyAsync destMax > src size";
  104. }
  105. stream_id_ = stream_id;
  106. std::shared_ptr<MemcpyAsyncTaskInfo> task_info_ptr = std::make_shared<MemcpyAsyncTaskInfo>(
  107. stream_id, outputs[0]->addr, outputs[0]->size, inputs[0]->addr, inputs[0]->size, RT_MEMCPY_DEVICE_TO_DEVICE);
  108. MS_EXCEPTION_IF_NULL(task_info_ptr);
  109. return {task_info_ptr};
  110. }
  111. const std::vector<TypeId> data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32,
  112. kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16,
  113. kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat, kNumberTypeFloat16,
  114. kNumberTypeFloat32, kNumberTypeFloat64, kNumberTypeBool};
  115. const std::vector<std::string> format_list = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC,
  116. kOpFormat_NC1HWC0, kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0,
  117. kOpFormat_C1HWNCoC0};
  118. MemCpyAsyncDesc::MemCpyAsyncDesc() {}
  119. MemCpyAsyncDesc::~MemCpyAsyncDesc() {}
  120. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> MemCpyAsyncDesc::GetKernelInfo() {
  121. std::vector<std::shared_ptr<kernel::KernelBuildInfo>> memcpy_build_info{};
  122. for (const auto &format : format_list) {
  123. for (const auto &type : data_type_list) {
  124. auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
  125. vector<string> input_format{format};
  126. vector<TypeId> input_type{type};
  127. vector<string> output_format{format};
  128. vector<TypeId> output_type{type};
  129. builder.SetInputsFormat(input_format);
  130. builder.SetInputsDeviceType(input_type);
  131. builder.SetOutputsFormat(output_format);
  132. builder.SetOutputsDeviceType(output_type);
  133. builder.SetProcessor(AICORE);
  134. builder.SetKernelType(RT_KERNEL);
  135. builder.SetFusionType(OPAQUE);
  136. memcpy_build_info.emplace_back(builder.Build());
  137. }
  138. }
  139. return memcpy_build_info;
  140. }
  141. } // namespace kernel
  142. } // namespace mindspore