You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ps_context.cc 16 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/ps_context.h"
  17. #include "utils/log_adapter.h"
  18. #include "utils/ms_utils.h"
  19. #include "backend/kernel_compiler/kernel.h"
  20. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  21. #include "ps/ps_cache/ps_cache_manager.h"
  22. #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
  23. #endif
  24. namespace mindspore {
  25. namespace ps {
  26. std::shared_ptr<PSContext> PSContext::instance() {
  27. static std::shared_ptr<PSContext> ps_instance = nullptr;
  28. if (ps_instance == nullptr) {
  29. ps_instance.reset(new (std::nothrow) PSContext());
  30. }
  31. return ps_instance;
  32. }
  33. void PSContext::SetPSEnable(bool enabled) {
  34. ps_enabled_ = enabled;
  35. if (ps_enabled_) {
  36. std::string ms_role = common::GetEnv(kEnvRole);
  37. MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
  38. if (ms_role == kEnvRoleOfWorker) {
  39. is_worker_ = true;
  40. } else if (ms_role == kEnvRoleOfPServer) {
  41. is_pserver_ = true;
  42. } else if (ms_role == kEnvRoleOfScheduler) {
  43. is_sched_ = true;
  44. } else {
  45. MS_LOG(INFO) << "MS_ROLE is " << ms_role;
  46. }
  47. worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
  48. server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
  49. scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
  50. if (scheduler_host_.length() > kLength) {
  51. MS_LOG(EXCEPTION) << "The scheduler host's length can not exceed " << kLength;
  52. }
  53. scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, kBase);
  54. if (scheduler_port_ > kMaxPort) {
  55. MS_LOG(EXCEPTION) << "The port: " << scheduler_port_ << " is illegal.";
  56. }
  57. scheduler_manage_port_ =
  58. static_cast<uint16_t>((std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, kBase)));
  59. if (scheduler_manage_port_ > kMaxPort) {
  60. MS_LOG(EXCEPTION) << "The port << " << scheduler_manage_port_ << " is illegal.";
  61. }
  62. cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
  63. node_id_ = common::GetEnv(kEnvNodeId);
  64. if (node_id_.length() > kLength) {
  65. MS_LOG(EXCEPTION) << "The node id length can not exceed " << kLength;
  66. }
  67. } else {
  68. MS_LOG(INFO) << "PS mode is disabled.";
  69. is_worker_ = false;
  70. is_pserver_ = false;
  71. is_sched_ = false;
  72. }
  73. }
  74. bool PSContext::is_ps_mode() const {
  75. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  76. return true;
  77. }
  78. return ps_enabled_;
  79. }
  80. void PSContext::Reset() {
  81. ps_enabled_ = false;
  82. is_worker_ = false;
  83. is_pserver_ = false;
  84. is_sched_ = false;
  85. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  86. if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
  87. ps_cache_instance.Finalize();
  88. set_cache_enable(false);
  89. }
  90. #endif
  91. }
  92. std::string PSContext::ms_role() const {
  93. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  94. return role_;
  95. }
  96. if (is_worker_) {
  97. return kEnvRoleOfWorker;
  98. } else if (is_pserver_) {
  99. return kEnvRoleOfPServer;
  100. } else if (is_sched_) {
  101. return kEnvRoleOfScheduler;
  102. } else {
  103. return kEnvRoleOfNotPS;
  104. }
  105. }
  106. bool PSContext::is_worker() const {
  107. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  108. return role_ == kEnvRoleOfWorker;
  109. }
  110. return is_worker_;
  111. }
  112. bool PSContext::is_server() const {
  113. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  114. return role_ == kEnvRoleOfServer;
  115. }
  116. return is_pserver_;
  117. }
  118. bool PSContext::is_scheduler() const {
  119. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  120. return role_ == kEnvRoleOfScheduler;
  121. }
  122. return is_sched_;
  123. }
  124. uint32_t PSContext::initial_worker_num() const { return worker_num_; }
  125. uint32_t PSContext::initial_server_num() const { return server_num_; }
  126. std::string PSContext::scheduler_host() const { return scheduler_host_; }
  127. void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
  128. uint32_t PSContext::ps_rank_id() const { return rank_id_; }
  129. void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
  130. size_t vocab_size) const {
  131. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  132. ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
  133. #endif
  134. }
  135. void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
  136. size_t cache_vocab_size, size_t embedding_size) const {
  137. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  138. ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
  139. #endif
  140. }
  141. void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
  142. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  143. ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
  144. #endif
  145. }
  146. void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
  147. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  148. ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
  149. #endif
  150. }
  151. void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
  152. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  153. ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
  154. #endif
  155. }
  156. void PSContext::set_cache_enable(bool cache_enable) const {
  157. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  158. PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
  159. #endif
  160. }
  161. void PSContext::set_rank_id(uint32_t rank_id) const {
  162. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  163. ps_cache_instance.set_rank_id(rank_id);
  164. #endif
  165. }
  166. void PSContext::set_server_mode(const std::string &server_mode) {
  167. if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
  168. MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
  169. << " or " << kServerModeHybrid;
  170. return;
  171. }
  172. MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";
  173. server_mode_ = server_mode;
  174. }
  175. const std::string &PSContext::server_mode() const { return server_mode_; }
  176. void PSContext::set_encrypt_type(const std::string &encrypt_type) {
  177. if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType &&
  178. encrypt_type != kStablePWEncryptType) {
  179. MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
  180. << kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType;
  181. return;
  182. }
  183. encrypt_type_ = encrypt_type;
  184. }
  185. const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
  186. void PSContext::set_dp_eps(float dp_eps) {
  187. if (dp_eps > 0) {
  188. dp_eps_ = dp_eps;
  189. } else {
  190. MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
  191. return;
  192. }
  193. }
  194. float PSContext::dp_eps() const { return dp_eps_; }
  195. void PSContext::set_dp_delta(float dp_delta) {
  196. if (dp_delta > 0 && dp_delta < 1) {
  197. dp_delta_ = dp_delta;
  198. } else {
  199. MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
  200. return;
  201. }
  202. }
  203. float PSContext::dp_delta() const { return dp_delta_; }
  204. void PSContext::set_dp_norm_clip(float dp_norm_clip) {
  205. if (dp_norm_clip > 0) {
  206. dp_norm_clip_ = dp_norm_clip;
  207. } else {
  208. MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
  209. return;
  210. }
  211. }
  212. float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
  213. void PSContext::set_ms_role(const std::string &role) {
  214. if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
  215. MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
  216. return;
  217. }
  218. if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
  219. MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
  220. return;
  221. }
  222. role_ = role;
  223. }
  224. void PSContext::set_worker_num(uint32_t worker_num) {
  225. // Hybrid training mode only supports one worker for now.
  226. if (server_mode_ == kServerModeHybrid && worker_num != 1) {
  227. MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
  228. return;
  229. }
  230. worker_num_ = worker_num;
  231. }
  232. uint32_t PSContext::worker_num() const { return worker_num_; }
  233. void PSContext::set_server_num(uint32_t server_num) {
  234. if (server_num == 0) {
  235. MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
  236. return;
  237. }
  238. server_num_ = server_num;
  239. }
  240. uint32_t PSContext::server_num() const { return server_num_; }
  241. void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
  242. std::string PSContext::scheduler_ip() const { return scheduler_host_; }
  243. void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
  244. uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
  245. void PSContext::GenerateResetterRound() {
  246. uint32_t binary_server_context = 0;
  247. bool is_parameter_server_mode = false;
  248. bool is_federated_learning_mode = false;
  249. bool is_mixed_training_mode = false;
  250. bool use_pairwise_encrypt = (encrypt_type_ == kPWEncryptType);
  251. if (server_mode_ == kServerModePS) {
  252. is_parameter_server_mode = true;
  253. } else if (server_mode_ == kServerModeFL) {
  254. is_federated_learning_mode = true;
  255. } else if (server_mode_ == kServerModeHybrid) {
  256. is_mixed_training_mode = true;
  257. } else {
  258. MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
  259. << " or " << kServerModeHybrid;
  260. return;
  261. }
  262. binary_server_context = ((unsigned int)is_parameter_server_mode) | ((unsigned int)is_federated_learning_mode << 1) |
  263. ((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)use_pairwise_encrypt << 3);
  264. if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
  265. resetter_round_ = ResetterRound::kNoNeedToReset;
  266. } else {
  267. resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
  268. }
  269. MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
  270. return;
  271. }
  272. ResetterRound PSContext::resetter_round() const { return resetter_round_; }
  273. void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }
  274. uint16_t PSContext::fl_server_port() const { return fl_server_port_; }
  275. void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }
  276. bool PSContext::fl_client_enable() const { return fl_client_enable_; }
  277. void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
  278. start_fl_job_threshold_ = start_fl_job_threshold;
  279. }
  280. uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
  281. void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
  282. start_fl_job_time_window_ = start_fl_job_time_window;
  283. }
  284. uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
  285. void PSContext::set_update_model_ratio(float update_model_ratio) {
  286. if (update_model_ratio > 1.0) {
  287. MS_LOG(EXCEPTION) << "update_model_ratio must be between 0 and 1.";
  288. return;
  289. }
  290. update_model_ratio_ = update_model_ratio;
  291. }
  292. float PSContext::update_model_ratio() const { return update_model_ratio_; }
  293. void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
  294. update_model_time_window_ = update_model_time_window;
  295. }
  296. uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
  297. void PSContext::set_share_secrets_ratio(float share_secrets_ratio) {
  298. if (share_secrets_ratio > 0 && share_secrets_ratio <= 1) {
  299. share_secrets_ratio_ = share_secrets_ratio;
  300. } else {
  301. MS_LOG(EXCEPTION) << share_secrets_ratio << " is invalid, share_secrets_ratio must be in range of (0, 1].";
  302. return;
  303. }
  304. }
  305. float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
  306. void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
  307. if (cipher_time_window_ < 0) {
  308. MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0.";
  309. return;
  310. }
  311. cipher_time_window_ = cipher_time_window;
  312. }
  313. uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
  314. void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) {
  315. if (reconstruct_secrets_threshold == 0) {
  316. MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive.";
  317. return;
  318. }
  319. reconstruct_secrets_threshold_ = reconstruct_secrets_threshold;
  320. }
  321. uint64_t PSContext::reconstruct_secrets_threshold() const { return reconstruct_secrets_threshold_; }
  322. void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
  323. const std::string &PSContext::fl_name() const { return fl_name_; }
  324. void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
  325. uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
  326. void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
  327. uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
  328. void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
  329. uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
  330. void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }
  331. float PSContext::client_learning_rate() const { return client_learning_rate_; }
  332. void PSContext::set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration) {
  333. worker_step_num_per_iteration_ = worker_step_num_per_iteration;
  334. }
  335. uint64_t PSContext::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
  336. bool PSContext::enable_ssl() const { return enable_ssl_; }
  337. void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
  338. core::ClusterConfig &PSContext::cluster_config() {
  339. if (cluster_config_ == nullptr) {
  340. cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
  341. MS_EXCEPTION_IF_NULL(cluster_config_);
  342. }
  343. return *cluster_config_;
  344. }
  345. void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
  346. uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
  347. void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
  348. std::string PSContext::config_file_path() const { return config_file_path_; }
  349. void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; }
  350. const std::string &PSContext::node_id() const { return node_id_; }
  351. std::string PSContext::client_password() const { return client_password_; }
  352. void PSContext::set_client_password(const std::string &password) { client_password_ = password; }
  353. std::string PSContext::server_password() const { return server_password_; }
  354. void PSContext::set_server_password(const std::string &password) { server_password_ = password; }
  355. } // namespace ps
  356. } // namespace mindspore