You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

hccl_kernel.cc 5.6 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/hccl/hccl_kernel.h"
  17. #include "device/ascend/tasksink/runtime_utils.h"
  18. #include "session/anf_runtime_algorithm.h"
  19. #include "utils/utils.h"
  20. using HcclTaskInfoPtr = std::shared_ptr<ge::model_runner::HcclTaskInfo>;
  21. using ge::model_runner::HcclTaskInfo;
  22. using mindspore::device::ascend::tasksink::RuntimeUtils;
  23. namespace mindspore {
  24. namespace kernel {
  25. void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
  26. hcclKernelMap_.emplace(name, std::move(fun));
  27. }
  28. std::shared_ptr<HcclKernel> HcclKernelFactory::Get(const std::string &name) {
  29. const auto &map = Get().hcclKernelMap_;
  30. auto it = map.find(name);
  31. if (it != map.end() && it->second) {
  32. return (it->second)();
  33. }
  34. return nullptr;
  35. }
  36. HcclKernelFactory &HcclKernelFactory::Get() {
  37. static HcclKernelFactory _this;
  38. return _this;
  39. }
  40. HcclKernel::HcclKernel() : hccl_count_(0), op_type_(HCCL_REP_OP_SUM), root_id_(0), anf_node_(nullptr) {}
  41. HcclKernel::~HcclKernel() {
  42. hccl_kernel_input_shape_list_.clear();
  43. hccl_kernel_output_shape_list_.clear();
  44. hccl_data_type_list_.clear();
  45. hccl_count_ = 0;
  46. op_type_ = HCCL_REP_OP_SUM;
  47. root_id_ = 0;
  48. input_size_list_.clear();
  49. output_size_list_.clear();
  50. workspace_size_list_.clear();
  51. anf_node_ = nullptr;
  52. }
  53. bool HcclKernel::Init(const AnfNodePtr &anf_node) {
  54. MS_EXCEPTION_IF_NULL(anf_node);
  55. op_name_ = AnfAlgo::GetCNodeName(anf_node);
  56. if (!HcomUtil::GetKernelInputShape(anf_node, &hccl_kernel_input_shape_list_)) {
  57. MS_LOG(ERROR) << "GetKernelInputShape fail!";
  58. return false;
  59. }
  60. if (!HcomUtil::GetKernelOutputShape(anf_node, &hccl_kernel_output_shape_list_)) {
  61. MS_LOG(ERROR) << "GetKernelOutputShape fail!";
  62. return false;
  63. }
  64. if (!HcomUtil::GetHcomDataType(anf_node, &hccl_data_type_list_)) {
  65. MS_LOG(ERROR) << "GetHcomDataType fail!";
  66. return false;
  67. }
  68. if (!HcomUtil::GetHcomCount(anf_node, hccl_data_type_list_, hccl_kernel_input_shape_list_, &hccl_count_)) {
  69. MS_LOG(ERROR) << "GetHcomCount fail!";
  70. return false;
  71. }
  72. if (op_name_ == kAllReduce || op_name_ == kReduceScatter) {
  73. if (!HcomUtil::GetHcomOperationType(anf_node, &op_type_)) {
  74. MS_LOG(ERROR) << "GetHcomOperationType fail!";
  75. return false;
  76. }
  77. }
  78. if (op_name_ == kBroadcast) {
  79. if (!HcomUtil::GetHcomRootId(anf_node, &root_id_)) {
  80. MS_LOG(ERROR) << "GetHcomRootId fail!";
  81. return false;
  82. }
  83. }
  84. HcomUtil::GetHcomGroup(NOT_NULL(anf_node), NOT_NULL(&group_));
  85. anf_node_ = anf_node;
  86. return true;
  87. }
  88. const std::vector<size_t> &HcclKernel::GetInputSizeList() const {
  89. size_t size = 0;
  90. if (!input_size_list_.empty()) {
  91. return input_size_list_;
  92. }
  93. for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) {
  94. if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_input_shape_list_[i], &size)) {
  95. MS_LOG(ERROR) << "GetHcclOpInputSize failed";
  96. }
  97. input_size_list_.push_back(size);
  98. }
  99. return input_size_list_;
  100. }
  101. const std::vector<size_t> &HcclKernel::GetOutputSizeList() const {
  102. size_t size = 0;
  103. if (!output_size_list_.empty()) {
  104. return output_size_list_;
  105. }
  106. for (ulong i = 0; i < hccl_data_type_list_.size(); ++i) {
  107. if (!HcomUtil::GetHcclOpSize(hccl_data_type_list_[i], hccl_kernel_output_shape_list_[i], &size)) {
  108. MS_LOG(ERROR) << "GetHcclOpOutputSize failed";
  109. }
  110. output_size_list_.push_back(size);
  111. }
  112. return output_size_list_;
  113. }
  114. const std::vector<size_t> &HcclKernel::GetWorkspaceSizeList() const { return workspace_size_list_; }
  115. std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inputs,
  116. const std::vector<AddressPtr> &workspace,
  117. const std::vector<AddressPtr> &outputs, uint32_t stream_id) {
  118. if (inputs.empty() || outputs.empty()) {
  119. MS_LOG(EXCEPTION) << "Inputs or outputs is empty";
  120. }
  121. stream_id_ = stream_id;
  122. std::string hccl_type = AnfAlgo::GetCNodeName(anf_node_);
  123. MS_EXCEPTION_IF_NULL(inputs.at(0));
  124. auto input_data_addr = inputs.at(0)->addr;
  125. MS_EXCEPTION_IF_NULL(outputs.at(0));
  126. auto output_data_addr = outputs.at(0)->addr;
  127. void *workspace_address = nullptr;
  128. const int64_t workspace_num = 0;
  129. std::vector<uint8_t> private_def;
  130. hcclDataType_t data_type = hccl_data_type_list_[0];
  131. MS_LOG(INFO) << "HCCL Task : stream_id=" << stream_id << ", ws_num=" << workspace_num << ", count=" << hccl_count_
  132. << ", root_id=" << root_id_ << ", op_type=" << static_cast<int>(op_type_)
  133. << ", data_type=" << static_cast<int>(data_type);
  134. HcclTaskInfoPtr task_info_ptr = std::make_shared<HcclTaskInfo>(
  135. stream_id, hccl_type, input_data_addr, output_data_addr, workspace_address, workspace_num, 0, private_def, nullptr,
  136. hccl_count_, root_id_, op_type_, data_type, group_, RuntimeUtils::HcomBindModel, RuntimeUtils::HcomUnbindModel,
  137. RuntimeUtils::HcomDistribute);
  138. MS_EXCEPTION_IF_NULL(task_info_ptr);
  139. return {task_info_ptr};
  140. }
  141. } // namespace kernel
  142. } // namespace mindspore