You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

runtime_utils.cc 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "device/ascend/tasksink/runtime_utils.h"
  17. #include <string>
  18. #include "hccl/hcom.h"
  19. #include "utils/log_adapter.h"
  20. #include "utils/utils.h"
  namespace {
  // Prefixes of the tags passed to the hcom collective calls in this file. HcomDistribute appends a sequence
  // number, then kUnderline, then "0" to each prefix.
  constexpr const char *kHcomBroadcast = "hcom_broadcast_";
  constexpr const char *kHcomAllGather = "hcom_all_gather_";
  constexpr const char *kHcomAllReduce = "hcom_all_reduce_";
  constexpr const char *kHcomReduceScatter = "hcom_reduce_scatter_";
  // Separator inserted between the sequence number and the trailing index of a tag.
  constexpr const char *kUnderline = "_";
  }  // namespace
  26. namespace mindspore {
  27. namespace device {
  28. namespace ascend {
  29. namespace tasksink {
  30. bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) {
  31. hcclResult_t ret = hcom_bind_model(model, stream);
  32. if (ret != HCCL_SUCCESS) {
  33. MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast<int>(ret);
  34. return false;
  35. }
  36. return true;
  37. }
  38. bool RuntimeUtils::HcomUnbindModel(rtModel_t model) {
  39. hcclResult_t ret = hcom_unbind_model(model);
  40. if (ret != HCCL_SUCCESS) {
  41. MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast<int>(ret);
  42. return false;
  43. }
  44. return true;
  45. }
  46. bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info, rtStream_t stream) {
  47. MS_LOG(INFO) << "hccl distribute start";
  48. MS_EXCEPTION_IF_NULL(task_info);
  49. hcclResult_t ret;
  50. static uint32_t task_counter = 0;
  51. auto hccl_group = task_info->group();
  52. if (task_info->hccl_type() == kBroadcastOpName) {
  53. // call hcom broadcast interface to run op
  54. const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
  55. ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast<u64>(task_info->count()),
  56. static_cast<hcclDataType_t>(task_info->data_type()), static_cast<u32>(task_info->root_id()),
  57. hccl_group.c_str(), stream);
  58. if (ret != HCCL_SUCCESS) {
  59. MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
  60. return false;
  61. }
  62. } else if (task_info->hccl_type() == kAllGatherOpName) {
  63. // call hcom allgather interface to run op
  64. const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
  65. ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
  66. static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
  67. hccl_group.c_str(), stream);
  68. if (ret != HCCL_SUCCESS) {
  69. MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
  70. return false;
  71. }
  72. } else if (task_info->hccl_type() == kAllReduceOpName) {
  73. // call hcom allreduce interface to run op
  74. const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0);
  75. ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
  76. static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
  77. static_cast<hcclRedOp_t>(task_info->op_type()), hccl_group.c_str(), stream);
  78. if (ret != HCCL_SUCCESS) {
  79. MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
  80. return false;
  81. }
  82. } else if (task_info->hccl_type() == kReduceScatterOpName) {
  83. // call hcom reducescatter interface to run op
  84. const string tag_reduce_scatter =
  85. kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0);
  86. ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
  87. static_cast<u64>(task_info->count()), static_cast<hcclDataType_t>(task_info->data_type()),
  88. static_cast<hcclRedOp_t>(task_info->op_type()), hccl_group.c_str(), stream);
  89. if (ret != HCCL_SUCCESS) {
  90. MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
  91. return false;
  92. }
  93. }
  94. return true;
  95. }
  96. } // namespace tasksink
  97. } // namespace ascend
  98. } // namespace device
  99. } // namespace mindspore