You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cluster_context.cc 6.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <vector>
  17. #include "distributed/cluster/cluster_context.h"
  18. #include "utils/ms_context.h"
  19. #include "ps/ps_context.h"
  20. #include "debug/common.h"
  21. namespace mindspore {
  22. namespace distributed {
  23. namespace cluster {
  24. ClusterContext::ClusterContext()
  25. : inited_(false),
  26. finalized_(true),
  27. node_num_each_role_({}),
  28. scheduler_host_(kLocalHost),
  29. scheduler_port_(kDefaultSchedPort),
  30. node_(nullptr),
  31. node_role_(kEnvRoleOfWorker),
  32. cluster_config_(nullptr) {}
  33. ClusterContext::~ClusterContext() {
  34. if (!finalized_) {
  35. Finalize();
  36. }
  37. }
  38. std::shared_ptr<ClusterContext> ClusterContext::instance() {
  39. static std::shared_ptr<ClusterContext> cluster_instance = nullptr;
  40. if (cluster_instance == nullptr) {
  41. cluster_instance.reset(new (std::nothrow) ClusterContext());
  42. MS_EXCEPTION_IF_NULL(cluster_instance);
  43. }
  44. return cluster_instance;
  45. }
  46. bool ClusterContext::Initialize() {
  47. if (inited_) {
  48. MS_LOG(INFO) << "The cluster has been initialized.";
  49. return true;
  50. }
  51. // Step 1: Initialize cluster configuration.
  52. InitClusterConfig();
  53. // Step 2: Build network for this cluster. Every process will block in this method until networking is done.
  54. if (!BuildCluster()) {
  55. MS_LOG(EXCEPTION) << "Building networking for " << node_role_ << " failed.";
  56. return false;
  57. }
  58. inited_ = true;
  59. finalized_ = false;
  60. return true;
  61. }
  62. bool ClusterContext::Finalize() {
  63. if (finalized_) {
  64. return true;
  65. }
  66. // In some cases, one node calls the Finish function while other nodes don't. So timeout is acceptable.
  67. if (!node_->Finish()) {
  68. MS_LOG(WARNING) << "Finishing node " << node_role_ << " timeout.";
  69. }
  70. if (!node_->Stop()) {
  71. MS_LOG(ERROR) << "Failed to stop node " << node_role_;
  72. return false;
  73. }
  74. finalized_ = true;
  75. wait_finish_cond_.notify_all();
  76. return true;
  77. }
  78. const std::shared_ptr<ps::core::Node> &ClusterContext::node() const { return node_; }
  79. void ClusterContext::InitClusterConfig() {
  80. InitNodeRole();
  81. InitSchedulerIp();
  82. InitSchedulerPort();
  83. ps::PSContext::instance()->set_worker_num(node_num_each_role_[kEnvRoleOfWorker]);
  84. ps::PSContext::instance()->set_server_num(node_num_each_role_[kEnvRoleOfServer]);
  85. ps::PSContext::instance()->set_scheduler_ip(scheduler_host_);
  86. ps::PSContext::instance()->set_scheduler_port(scheduler_port_);
  87. ps::PSContext::instance()->cluster_config().initial_worker_num = node_num_each_role_[kEnvRoleOfWorker];
  88. ps::PSContext::instance()->cluster_config().initial_server_num = node_num_each_role_[kEnvRoleOfServer];
  89. ps::PSContext::instance()->cluster_config().scheduler_host = scheduler_host_;
  90. ps::PSContext::instance()->cluster_config().scheduler_port = scheduler_port_;
  91. }
  92. bool ClusterContext::BuildCluster() {
  93. // Create node according to different role.
  94. if (node_role_ == kEnvRoleOfWorker) {
  95. node_ = std::make_shared<ps::core::WorkerNode>();
  96. } else if (node_role_ == kEnvRoleOfServer) {
  97. node_ = std::make_shared<ps::core::ServerNode>();
  98. } else if (node_role_ == kEnvRoleOfScheduler) {
  99. node_ = std::make_shared<ps::core::SchedulerNode>();
  100. } else {
  101. MS_LOG(EXCEPTION) << "The role " << node_role_ << " is invalid.";
  102. return false;
  103. }
  104. MS_EXCEPTION_IF_NULL(node_);
  105. RegisterEventCallback();
  106. if (!node_->Start()) {
  107. MS_LOG(EXCEPTION) << "Building network failed.";
  108. return false;
  109. }
  110. MS_LOG(INFO) << "Cluster is successfully initialized.";
  111. return true;
  112. }
  113. void ClusterContext::InitNodeRole() {
  114. node_role_ = common::GetEnv(kEnvRole);
  115. if (kValidRoleName.count(node_role_) == 0) {
  116. MS_LOG(EXCEPTION) << "Role name " << node_role_ << " is invalid.";
  117. return;
  118. }
  119. if (common::GetEnv(kEnvWorkerNum).empty()) {
  120. node_num_each_role_[kEnvRoleOfWorker] = 0;
  121. } else {
  122. TRY_AND_CATCH_WITH_EXCEPTION(
  123. (node_num_each_role_[kEnvRoleOfWorker] = IntToUint(std::stoi(common::GetEnv(kEnvWorkerNum)))),
  124. "The environment variable MS_WORKER_NUM is invalid.");
  125. }
  126. if (common::GetEnv(kEnvServerNum).empty()) {
  127. node_num_each_role_[kEnvRoleOfServer] = 0;
  128. } else {
  129. TRY_AND_CATCH_WITH_EXCEPTION(
  130. (node_num_each_role_[kEnvRoleOfServer] = IntToUint(std::stoi(common::GetEnv(kEnvServerNum)))),
  131. "The environment variable MS_SERVER_NUM is invalid.");
  132. }
  133. }
  134. void ClusterContext::InitSchedulerIp() {
  135. scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
  136. if (scheduler_host_ != kLocalHost) {
  137. MS_LOG(EXCEPTION) << "Scheduler IP should be 127.0.0.1";
  138. }
  139. }
  140. void ClusterContext::InitSchedulerPort() {
  141. TRY_AND_CATCH_WITH_EXCEPTION((scheduler_port_ = static_cast<uint16_t>(std::stoi(common::GetEnv(kEnvSchedulerPort)))),
  142. "The environment variable MS_SCHED_PORT is invalid.");
  143. if (scheduler_port_ > kMaxPort) {
  144. MS_LOG(EXCEPTION) << "The port: " << scheduler_port_ << " is invalid.";
  145. }
  146. }
  147. void ClusterContext::RegisterEventCallback() {
  148. auto abstract_node = std::dynamic_pointer_cast<ps::core::AbstractNode>(node_);
  149. if (abstract_node != nullptr) {
  150. abstract_node->RegisterEventCallback(ps::core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
  151. MS_LOG(ERROR) << "Event SCHEDULER_TIMEOUT is captured.";
  152. Finalize();
  153. try {
  154. MS_LOG(EXCEPTION)
  155. << "Event SCHEDULER_TIMEOUT is captured. This is because scheduler node is finalized or crashed.";
  156. } catch (std::exception &) {
  157. MsException::Instance().SetException();
  158. }
  159. });
  160. abstract_node->RegisterEventCallback(ps::core::ClusterEvent::NODE_TIMEOUT, [this]() {
  161. MS_LOG(ERROR) << "Event NODE_TIMEOUT is captured.";
  162. Finalize();
  163. try {
  164. MS_LOG(EXCEPTION) << "Event NODE_TIMEOUT is captured. This is because some nodes are finalized or crashed.";
  165. } catch (std::exception &) {
  166. MsException::Instance().SetException();
  167. }
  168. });
  169. }
  170. }
  171. } // namespace cluster
  172. } // namespace distributed
  173. } // namespace mindspore