You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

comm_manager.cc 3.3 kB

4 years ago
4 years ago
Lines 1–91
  1. /**
  2. * Copyright 2019-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "include/common/utils/comm_manager.h"
  17. #include "include/common/utils/convert_utils.h"
  18. #include "utils/ms_context.h"
  19. #include "include/common/utils/parallel_context.h"
  20. namespace mindspore {
  21. namespace {
  22. constexpr auto kDefaultCommManagerName = "default_comm_manager_name";
  23. constexpr unsigned int kNoCommDlibRankSize = 2048;
  24. std::map<std::string, std::shared_ptr<CommManager>> &GetInstanceMap() {
  25. static std::map<std::string, std::shared_ptr<CommManager>> kCommInstanceMap = {};
  26. return kCommInstanceMap;
  27. }
  28. class DefaultCommManager : public CommManager {
  29. public:
  30. DefaultCommManager() : CommManager("hccl") {}
  31. ~DefaultCommManager() override = default;
  32. bool CreateGroupSync(const string &, const std::vector<unsigned int> &) const override { return true; }
  33. bool GetRankID(const string &group, unsigned int *rank_id) const override { return true; }
  34. bool GetRankSize(const string &group, unsigned int *rank_size) const override {
  35. *rank_size = kNoCommDlibRankSize;
  36. return true;
  37. }
  38. bool DestroyGroup(const string &group) const override { return true; }
  39. uint32_t GetRank() override { return 0; }
  40. };
  41. COMM_MANAGER_REG(kDefaultCommManagerName, std::make_shared<DefaultCommManager>());
  42. } // namespace
  43. bool CommManager::Register(const std::string &name, const std::shared_ptr<CommManager> &instance) {
  44. if (GetInstanceMap().find(name) != GetInstanceMap().end()) {
  45. return false;
  46. }
  47. GetInstanceMap().emplace(name, instance);
  48. return true;
  49. }
  50. CommManager &CommManager::GetInstance() noexcept {
  51. if (GetInstanceMap().empty()) {
  52. MS_LOG(EXCEPTION) << "No CommManager instance found.";
  53. }
  54. auto default_iter = GetInstanceMap().find(kDefaultCommManagerName);
  55. if (default_iter == GetInstanceMap().end()) {
  56. MS_LOG(EXCEPTION) << "Default CommManager instance not found.";
  57. }
  58. auto context_ptr = MsContext::GetInstance();
  59. MS_EXCEPTION_IF_NULL(context_ptr);
  60. std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  61. if (auto iter = GetInstanceMap().find(device_name); iter != GetInstanceMap().end()) {
  62. return *(iter->second);
  63. }
  64. if (static bool first_warning = true; first_warning) {
  65. MS_LOG(WARNING) << "CommManager instance for " << device_name << " not found, return default instance.";
  66. first_warning = false;
  67. }
  68. return *(default_iter->second);
  69. }
  70. uint32_t GetRank() { return CommManager::GetInstance().GetRank(); }
  71. bool IsStandAlone() {
  72. auto parallel_context = parallel::ParallelContext::GetInstance();
  73. MS_EXCEPTION_IF_NULL(parallel_context);
  74. return parallel_context->parallel_mode() == parallel::kStandalone;
  75. }
  76. } // namespace mindspore