You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ps_context.cc 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/ps_context.h"
  17. #include "utils/log_adapter.h"
  18. #include "utils/ms_utils.h"
  19. #include "backend/kernel_compiler/kernel.h"
  20. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  21. #include "ps/ps_cache/ps_cache_manager.h"
  22. #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
  23. #endif
  24. namespace mindspore {
  25. namespace ps {
  26. std::shared_ptr<PSContext> PSContext::instance() {
  27. static std::shared_ptr<PSContext> ps_instance = nullptr;
  28. if (ps_instance == nullptr) {
  29. ps_instance.reset(new (std::nothrow) PSContext());
  30. }
  31. return ps_instance;
  32. }
  33. void PSContext::SetPSEnable(bool enabled) {
  34. ps_enabled_ = enabled;
  35. if (ps_enabled_) {
  36. std::string ms_role = common::GetEnv(kEnvRole);
  37. MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
  38. if (ms_role == kEnvRoleOfWorker) {
  39. is_worker_ = true;
  40. } else if (ms_role == kEnvRoleOfPServer) {
  41. is_pserver_ = true;
  42. } else if (ms_role == kEnvRoleOfScheduler) {
  43. is_sched_ = true;
  44. } else {
  45. MS_LOG(INFO) << "MS_ROLE is " << ms_role;
  46. }
  47. worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
  48. server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
  49. scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
  50. scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, kBase);
  51. scheduler_manage_port_ = std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, kBase);
  52. cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
  53. node_id_ = common::GetEnv(kEnvNodeId);
  54. } else {
  55. MS_LOG(INFO) << "PS mode is disabled.";
  56. is_worker_ = false;
  57. is_pserver_ = false;
  58. is_sched_ = false;
  59. }
  60. }
  61. bool PSContext::is_ps_mode() const {
  62. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  63. return true;
  64. }
  65. return ps_enabled_;
  66. }
  67. void PSContext::Reset() {
  68. ps_enabled_ = false;
  69. is_worker_ = false;
  70. is_pserver_ = false;
  71. is_sched_ = false;
  72. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  73. if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
  74. ps_cache_instance.Finalize();
  75. set_cache_enable(false);
  76. }
  77. #endif
  78. }
  79. std::string PSContext::ms_role() const {
  80. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  81. return role_;
  82. }
  83. if (is_worker_) {
  84. return kEnvRoleOfWorker;
  85. } else if (is_pserver_) {
  86. return kEnvRoleOfPServer;
  87. } else if (is_sched_) {
  88. return kEnvRoleOfScheduler;
  89. } else {
  90. return kEnvRoleOfNotPS;
  91. }
  92. }
  93. bool PSContext::is_worker() const {
  94. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  95. return role_ == kEnvRoleOfWorker;
  96. }
  97. return is_worker_;
  98. }
  99. bool PSContext::is_server() const {
  100. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  101. return role_ == kEnvRoleOfServer;
  102. }
  103. return is_pserver_;
  104. }
  105. bool PSContext::is_scheduler() const {
  106. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  107. return role_ == kEnvRoleOfScheduler;
  108. }
  109. return is_sched_;
  110. }
  111. uint32_t PSContext::initial_worker_num() const { return worker_num_; }
  112. uint32_t PSContext::initial_server_num() const { return server_num_; }
  113. std::string PSContext::scheduler_host() const { return scheduler_host_; }
  114. void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
  115. uint32_t PSContext::ps_rank_id() const { return rank_id_; }
  116. void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
  117. size_t vocab_size) const {
  118. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  119. ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
  120. #endif
  121. }
  122. void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
  123. size_t cache_vocab_size, size_t embedding_size) const {
  124. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  125. ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
  126. #endif
  127. }
  128. void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
  129. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  130. ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
  131. #endif
  132. }
  133. void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
  134. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  135. ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
  136. #endif
  137. }
  138. void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
  139. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  140. ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
  141. #endif
  142. }
  143. void PSContext::set_cache_enable(bool cache_enable) const {
  144. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  145. PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
  146. #endif
  147. }
  148. void PSContext::set_rank_id(uint32_t rank_id) const {
  149. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  150. ps_cache_instance.set_rank_id(rank_id);
  151. #endif
  152. }
  153. void PSContext::set_server_mode(const std::string &server_mode) {
  154. if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
  155. MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
  156. << " or " << kServerModeHybrid;
  157. return;
  158. }
  159. MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";
  160. server_mode_ = server_mode;
  161. }
  162. const std::string &PSContext::server_mode() const { return server_mode_; }
  163. void PSContext::set_encrypt_type(const std::string &encrypt_type) {
  164. if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType) {
  165. MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
  166. << kDPEncryptType << " or " << kPWEncryptType;
  167. return;
  168. }
  169. encrypt_type_ = encrypt_type;
  170. }
  171. const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
  172. void PSContext::set_dp_eps(float dp_eps) {
  173. if (dp_eps > 0) {
  174. dp_eps_ = dp_eps;
  175. } else {
  176. MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
  177. return;
  178. }
  179. }
  180. float PSContext::dp_eps() const { return dp_eps_; }
  181. void PSContext::set_dp_delta(float dp_delta) {
  182. if (dp_delta > 0 && dp_delta < 1) {
  183. dp_delta_ = dp_delta;
  184. } else {
  185. MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
  186. return;
  187. }
  188. }
  189. float PSContext::dp_delta() const { return dp_delta_; }
  190. void PSContext::set_dp_norm_clip(float dp_norm_clip) {
  191. if (dp_norm_clip > 0) {
  192. dp_norm_clip_ = dp_norm_clip;
  193. } else {
  194. MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
  195. return;
  196. }
  197. }
  198. float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
  199. void PSContext::set_ms_role(const std::string &role) {
  200. if (server_mode_ != kServerModeFL && server_mode_ != kServerModeHybrid) {
  201. MS_LOG(EXCEPTION) << "Only federated learning supports to set role by fl context.";
  202. return;
  203. }
  204. if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
  205. MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
  206. return;
  207. }
  208. role_ = role;
  209. }
  210. void PSContext::set_worker_num(uint32_t worker_num) {
  211. // Hybrid training mode only supports one worker for now.
  212. if (server_mode_ == kServerModeHybrid && worker_num != 1) {
  213. MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
  214. return;
  215. }
  216. worker_num_ = worker_num;
  217. }
  218. uint32_t PSContext::worker_num() const { return worker_num_; }
  219. void PSContext::set_server_num(uint32_t server_num) {
  220. if (server_num == 0) {
  221. MS_LOG(EXCEPTION) << "Server number must be greater than 0.";
  222. return;
  223. }
  224. server_num_ = server_num;
  225. }
  226. uint32_t PSContext::server_num() const { return server_num_; }
  227. void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
  228. std::string PSContext::scheduler_ip() const { return scheduler_host_; }
  229. void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
  230. uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
  231. void PSContext::GenerateResetterRound() {
  232. uint32_t binary_server_context = 0;
  233. bool is_parameter_server_mode = false;
  234. bool is_federated_learning_mode = false;
  235. bool is_mixed_training_mode = false;
  236. bool use_pairwise_encrypt = (encrypt_type_ == kPWEncryptType);
  237. if (server_mode_ == kServerModePS) {
  238. is_parameter_server_mode = true;
  239. } else if (server_mode_ == kServerModeFL) {
  240. is_federated_learning_mode = true;
  241. } else if (server_mode_ == kServerModeHybrid) {
  242. is_mixed_training_mode = true;
  243. } else {
  244. MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
  245. << " or " << kServerModeHybrid;
  246. return;
  247. }
  248. binary_server_context = ((unsigned int)is_parameter_server_mode << 0) |
  249. ((unsigned int)is_federated_learning_mode << 1) |
  250. ((unsigned int)is_mixed_training_mode << 2) | ((unsigned int)use_pairwise_encrypt << 3);
  251. if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
  252. resetter_round_ = ResetterRound::kNoNeedToReset;
  253. } else {
  254. resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
  255. }
  256. MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
  257. return;
  258. }
  259. ResetterRound PSContext::resetter_round() const { return resetter_round_; }
  260. void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }
  261. uint16_t PSContext::fl_server_port() const { return fl_server_port_; }
  262. void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }
  263. bool PSContext::fl_client_enable() { return fl_client_enable_; }
  264. void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
  265. start_fl_job_threshold_ = start_fl_job_threshold;
  266. }
  267. uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
  268. void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
  269. start_fl_job_time_window_ = start_fl_job_time_window;
  270. }
  271. uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
  272. void PSContext::set_update_model_ratio(float update_model_ratio) {
  273. if (update_model_ratio > 1.0) {
  274. MS_LOG(EXCEPTION) << "update_model_ratio must be between 0 and 1.";
  275. return;
  276. }
  277. update_model_ratio_ = update_model_ratio;
  278. }
  279. float PSContext::update_model_ratio() const { return update_model_ratio_; }
  280. void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
  281. update_model_time_window_ = update_model_time_window;
  282. }
  283. uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
  284. void PSContext::set_share_secrets_ratio(float share_secrets_ratio) {
  285. if (share_secrets_ratio > 0 && share_secrets_ratio <= 1) {
  286. share_secrets_ratio_ = share_secrets_ratio;
  287. } else {
  288. MS_LOG(EXCEPTION) << share_secrets_ratio << " is invalid, share_secrets_ratio must be in range of (0, 1].";
  289. return;
  290. }
  291. }
  292. float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
  293. void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
  294. if (cipher_time_window_ < 0) {
  295. MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0.";
  296. return;
  297. }
  298. cipher_time_window_ = cipher_time_window;
  299. }
  300. uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
  301. void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) {
  302. if (reconstruct_secrets_threshold == 0) {
  303. MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive.";
  304. return;
  305. }
  306. reconstruct_secrets_threshold_ = reconstruct_secrets_threshold;
  307. }
  308. uint64_t PSContext::reconstruct_secrets_threshold() const { return reconstruct_secrets_threshold_; }
  309. void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
  310. const std::string &PSContext::fl_name() const { return fl_name_; }
  311. void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
  312. uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
  313. void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
  314. uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
  315. void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
  316. uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
  317. void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }
  318. float PSContext::client_learning_rate() const { return client_learning_rate_; }
  319. void PSContext::set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration) {
  320. worker_step_num_per_iteration_ = worker_step_num_per_iteration;
  321. }
  322. uint64_t PSContext::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
  323. bool PSContext::enable_ssl() const { return enable_ssl_; }
  324. void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
  325. core::ClusterConfig &PSContext::cluster_config() {
  326. if (cluster_config_ == nullptr) {
  327. MS_LOG(EXCEPTION) << "The cluster config is empty.";
  328. }
  329. return *cluster_config_;
  330. }
  331. void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
  332. uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
  333. void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
  334. std::string PSContext::config_file_path() const { return config_file_path_; }
  335. void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; }
  336. const std::string &PSContext::node_id() const { return node_id_; }
  337. } // namespace ps
  338. } // namespace mindspore