You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ps_context.cc 18 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "ps/ps_context.h"
  17. #include "utils/log_adapter.h"
  18. #include "utils/ms_utils.h"
  19. #include "backend/kernel_compiler/kernel.h"
  20. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  21. #include "ps/ps_cache/ps_cache_manager.h"
  22. #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
  23. #include "distributed/cluster/cluster_context.h"
  24. #else
  25. #include "distributed/cluster/dummy_cluster_context.h"
  26. #endif
  27. namespace mindspore {
  28. namespace ps {
  29. std::shared_ptr<PSContext> PSContext::instance() {
  30. static std::shared_ptr<PSContext> ps_instance = nullptr;
  31. if (ps_instance == nullptr) {
  32. ps_instance.reset(new PSContext());
  33. }
  34. return ps_instance;
  35. }
  36. void PSContext::SetPSEnable(bool enabled) {
  37. ps_enabled_ = enabled;
  38. if (ps_enabled_) {
  39. std::string ms_role = common::GetEnv(kEnvRole);
  40. MS_LOG(INFO) << "PS mode is enabled. MS_ROLE is " << ms_role;
  41. if (ms_role == kEnvRoleOfWorker) {
  42. is_worker_ = true;
  43. } else if (ms_role == kEnvRoleOfPServer) {
  44. is_pserver_ = true;
  45. } else if (ms_role == kEnvRoleOfScheduler) {
  46. is_sched_ = true;
  47. } else {
  48. MS_LOG(INFO) << "MS_ROLE is " << ms_role;
  49. }
  50. worker_num_ = std::strtol(common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
  51. server_num_ = std::strtol(common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
  52. scheduler_host_ = common::GetEnv(kEnvSchedulerHost);
  53. if (scheduler_host_.length() > kLength) {
  54. MS_LOG(EXCEPTION) << "The scheduler host's length can not exceed " << kLength;
  55. }
  56. scheduler_port_ = std::strtol(common::GetEnv(kEnvSchedulerPort).c_str(), nullptr, kBase);
  57. if (scheduler_port_ > kMaxPort) {
  58. MS_LOG(EXCEPTION) << "The port: " << scheduler_port_ << " is illegal.";
  59. }
  60. scheduler_manage_port_ =
  61. static_cast<uint16_t>((std::strtol(common::GetEnv(kEnvSchedulerManagePort).c_str(), nullptr, kBase)));
  62. if (scheduler_manage_port_ > kMaxPort) {
  63. MS_LOG(EXCEPTION) << "The port << " << scheduler_manage_port_ << " is illegal.";
  64. }
  65. cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
  66. node_id_ = common::GetEnv(kEnvNodeId);
  67. if (node_id_.length() > kLength) {
  68. MS_LOG(EXCEPTION) << "The node id length can not exceed " << kLength;
  69. }
  70. } else {
  71. MS_LOG(INFO) << "PS mode is disabled.";
  72. is_worker_ = false;
  73. is_pserver_ = false;
  74. is_sched_ = false;
  75. }
  76. }
  77. bool PSContext::is_ps_mode() const {
  78. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  79. return true;
  80. }
  81. return ps_enabled_;
  82. }
  83. void PSContext::Reset() {
  84. ps_enabled_ = false;
  85. is_worker_ = false;
  86. is_pserver_ = false;
  87. is_sched_ = false;
  88. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  89. if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
  90. ps_cache_instance.Finalize();
  91. set_cache_enable(false);
  92. }
  93. #endif
  94. }
  95. std::string PSContext::ms_role() const {
  96. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  97. return role_;
  98. }
  99. if (is_worker_) {
  100. return kEnvRoleOfWorker;
  101. } else if (is_pserver_) {
  102. return kEnvRoleOfPServer;
  103. } else if (is_sched_) {
  104. return kEnvRoleOfScheduler;
  105. } else {
  106. return kEnvRoleOfNotPS;
  107. }
  108. }
  109. bool PSContext::is_worker() const {
  110. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  111. return role_ == kEnvRoleOfWorker;
  112. }
  113. if (distributed::cluster::ClusterContext::instance()->initialized()) {
  114. return role_ == kEnvRoleOfWorker;
  115. }
  116. return is_worker_;
  117. }
  118. bool PSContext::is_server() const {
  119. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  120. return role_ == kEnvRoleOfServer;
  121. }
  122. if (distributed::cluster::ClusterContext::instance()->initialized()) {
  123. return role_ == kEnvRoleOfServer;
  124. }
  125. return is_pserver_;
  126. }
  127. bool PSContext::is_scheduler() const {
  128. if ((server_mode_ == kServerModeFL || server_mode_ == kServerModeHybrid) && ps_enabled_) {
  129. return role_ == kEnvRoleOfScheduler;
  130. }
  131. if (distributed::cluster::ClusterContext::instance()->initialized()) {
  132. return role_ == kEnvRoleOfScheduler;
  133. }
  134. return is_sched_;
  135. }
  136. uint32_t PSContext::initial_worker_num() const { return worker_num_; }
  137. uint32_t PSContext::initial_server_num() const { return server_num_; }
  138. std::string PSContext::scheduler_host() const { return scheduler_host_; }
  139. void PSContext::SetPSRankId(uint32_t rank_id) { rank_id_ = rank_id; }
  140. uint32_t PSContext::ps_rank_id() const { return rank_id_; }
  141. void PSContext::InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
  142. size_t vocab_size) const {
  143. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  144. ps_cache_instance.InsertHashTableSize(param_name, cache_vocab_size, embedding_size, vocab_size);
  145. #endif
  146. }
  147. void PSContext::ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
  148. size_t cache_vocab_size, size_t embedding_size) const {
  149. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  150. ps_cache_instance.ReInsertHashTableSize(new_param_name, cur_param_name, cache_vocab_size, embedding_size);
  151. #endif
  152. }
  153. void PSContext::InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const {
  154. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  155. ps_cache_instance.InsertWeightInitInfo(param_name, global_seed, op_seed);
  156. #endif
  157. }
  158. void PSContext::InsertAccumuInitInfo(const std::string &param_name, float init_val) const {
  159. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  160. ps_cache_instance.InsertAccumuInitInfo(param_name, init_val);
  161. #endif
  162. }
  163. void PSContext::CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const {
  164. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  165. ps_cache_instance.CloneHashTable(dest_param_name, src_param_name);
  166. #endif
  167. }
  168. void PSContext::set_cache_enable(bool cache_enable) const {
  169. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  170. PsDataPrefetch::GetInstance().set_cache_enable(cache_enable);
  171. #endif
  172. }
  173. void PSContext::set_rank_id(uint32_t rank_id) const {
  174. #if ((defined ENABLE_CPU) && (!defined _WIN32))
  175. ps_cache_instance.set_rank_id(rank_id);
  176. #endif
  177. }
  178. void PSContext::set_server_mode(const std::string &server_mode) {
  179. if (server_mode != kServerModePS && server_mode != kServerModeFL && server_mode != kServerModeHybrid) {
  180. MS_LOG(EXCEPTION) << server_mode << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
  181. << " or " << kServerModeHybrid;
  182. return;
  183. }
  184. MS_LOG(INFO) << "Server mode: " << server_mode << " is used for Server and Worker. Scheduler will ignore it.";
  185. server_mode_ = server_mode;
  186. }
  187. const std::string &PSContext::server_mode() const { return server_mode_; }
  188. void PSContext::set_encrypt_type(const std::string &encrypt_type) {
  189. if (encrypt_type != kNotEncryptType && encrypt_type != kDPEncryptType && encrypt_type != kPWEncryptType &&
  190. encrypt_type != kStablePWEncryptType) {
  191. MS_LOG(EXCEPTION) << encrypt_type << " is invalid. Encrypt type must be " << kNotEncryptType << " or "
  192. << kDPEncryptType << " or " << kPWEncryptType << " or " << kStablePWEncryptType;
  193. return;
  194. }
  195. encrypt_type_ = encrypt_type;
  196. }
  197. const std::string &PSContext::encrypt_type() const { return encrypt_type_; }
  198. void PSContext::set_dp_eps(float dp_eps) {
  199. if (dp_eps > 0) {
  200. dp_eps_ = dp_eps;
  201. } else {
  202. MS_LOG(EXCEPTION) << dp_eps << " is invalid, dp_eps must be larger than 0.";
  203. return;
  204. }
  205. }
  206. float PSContext::dp_eps() const { return dp_eps_; }
  207. void PSContext::set_dp_delta(float dp_delta) {
  208. if (dp_delta > 0 && dp_delta < 1) {
  209. dp_delta_ = dp_delta;
  210. } else {
  211. MS_LOG(EXCEPTION) << dp_delta << " is invalid, dp_delta must be in range of (0, 1).";
  212. return;
  213. }
  214. }
  215. float PSContext::dp_delta() const { return dp_delta_; }
  216. void PSContext::set_dp_norm_clip(float dp_norm_clip) {
  217. if (dp_norm_clip > 0) {
  218. dp_norm_clip_ = dp_norm_clip;
  219. } else {
  220. MS_LOG(EXCEPTION) << dp_norm_clip << " is invalid, dp_norm_clip must be larger than 0.";
  221. return;
  222. }
  223. }
  224. float PSContext::dp_norm_clip() const { return dp_norm_clip_; }
  225. void PSContext::set_ms_role(const std::string &role) {
  226. if (role != kEnvRoleOfWorker && role != kEnvRoleOfServer && role != kEnvRoleOfScheduler) {
  227. MS_LOG(EXCEPTION) << "ms_role " << role << " is invalid.";
  228. return;
  229. }
  230. role_ = role;
  231. }
  232. void PSContext::set_worker_num(uint32_t worker_num) {
  233. // Hybrid training mode only supports one worker for now.
  234. if (server_mode_ == kServerModeHybrid && worker_num != 1) {
  235. MS_LOG(EXCEPTION) << "The worker number should be set to 1 in hybrid training mode.";
  236. return;
  237. }
  238. worker_num_ = worker_num;
  239. }
  240. uint32_t PSContext::worker_num() const { return worker_num_; }
  241. void PSContext::set_server_num(uint32_t server_num) { server_num_ = server_num; }
  242. uint32_t PSContext::server_num() const { return server_num_; }
  243. void PSContext::set_scheduler_ip(const std::string &sched_ip) { scheduler_host_ = sched_ip; }
  244. std::string PSContext::scheduler_ip() const { return scheduler_host_; }
  245. void PSContext::set_scheduler_port(uint16_t sched_port) { scheduler_port_ = sched_port; }
  246. uint16_t PSContext::scheduler_port() const { return scheduler_port_; }
  247. void PSContext::GenerateResetterRound() {
  248. uint32_t binary_server_context = 0;
  249. bool is_parameter_server_mode = false;
  250. bool is_federated_learning_mode = false;
  251. bool is_mixed_training_mode = false;
  252. bool is_pairwise_encrypt = (encrypt_type_ == kPWEncryptType);
  253. if (server_mode_ == kServerModePS) {
  254. is_parameter_server_mode = true;
  255. } else if (server_mode_ == kServerModeFL) {
  256. is_federated_learning_mode = true;
  257. } else if (server_mode_ == kServerModeHybrid) {
  258. is_mixed_training_mode = true;
  259. } else {
  260. MS_LOG(EXCEPTION) << server_mode_ << " is invalid. Server mode must be " << kServerModePS << " or " << kServerModeFL
  261. << " or " << kServerModeHybrid;
  262. return;
  263. }
  264. const int training_mode_offset = 2;
  265. const int pairwise_encrypt_offset = 3;
  266. binary_server_context = ((unsigned int)is_parameter_server_mode) | ((unsigned int)is_federated_learning_mode << 1) |
  267. ((unsigned int)is_mixed_training_mode << training_mode_offset) |
  268. ((unsigned int)is_pairwise_encrypt << pairwise_encrypt_offset);
  269. if (kServerContextToResetRoundMap.count(binary_server_context) == 0) {
  270. resetter_round_ = ResetterRound::kNoNeedToReset;
  271. } else {
  272. resetter_round_ = kServerContextToResetRoundMap.at(binary_server_context);
  273. }
  274. MS_LOG(INFO) << "Server context is " << binary_server_context << ". Resetter round is " << resetter_round_;
  275. return;
  276. }
  277. ResetterRound PSContext::resetter_round() const { return resetter_round_; }
  278. void PSContext::set_fl_server_port(uint16_t fl_server_port) { fl_server_port_ = fl_server_port; }
  279. uint16_t PSContext::fl_server_port() const { return fl_server_port_; }
  280. void PSContext::set_fl_client_enable(bool enabled) { fl_client_enable_ = enabled; }
  281. bool PSContext::fl_client_enable() const { return fl_client_enable_; }
  282. void PSContext::set_start_fl_job_threshold(uint64_t start_fl_job_threshold) {
  283. start_fl_job_threshold_ = start_fl_job_threshold;
  284. }
  285. uint64_t PSContext::start_fl_job_threshold() const { return start_fl_job_threshold_; }
  286. void PSContext::set_start_fl_job_time_window(uint64_t start_fl_job_time_window) {
  287. start_fl_job_time_window_ = start_fl_job_time_window;
  288. }
  289. uint64_t PSContext::start_fl_job_time_window() const { return start_fl_job_time_window_; }
  290. void PSContext::set_update_model_ratio(float update_model_ratio) {
  291. if (update_model_ratio > 1.0) {
  292. MS_LOG(EXCEPTION) << "update_model_ratio must be between 0 and 1.";
  293. return;
  294. }
  295. update_model_ratio_ = update_model_ratio;
  296. }
  297. float PSContext::update_model_ratio() const { return update_model_ratio_; }
  298. void PSContext::set_update_model_time_window(uint64_t update_model_time_window) {
  299. update_model_time_window_ = update_model_time_window;
  300. }
  301. uint64_t PSContext::update_model_time_window() const { return update_model_time_window_; }
  302. void PSContext::set_share_secrets_ratio(float share_secrets_ratio) {
  303. if (share_secrets_ratio > 0 && share_secrets_ratio <= 1) {
  304. share_secrets_ratio_ = share_secrets_ratio;
  305. } else {
  306. MS_LOG(EXCEPTION) << share_secrets_ratio << " is invalid, share_secrets_ratio must be in range of (0, 1].";
  307. return;
  308. }
  309. }
  310. float PSContext::share_secrets_ratio() const { return share_secrets_ratio_; }
  311. void PSContext::set_cipher_time_window(uint64_t cipher_time_window) {
  312. if (cipher_time_window_ < 0) {
  313. MS_LOG(EXCEPTION) << "cipher_time_window should not be less than 0.";
  314. return;
  315. }
  316. cipher_time_window_ = cipher_time_window;
  317. }
  318. uint64_t PSContext::cipher_time_window() const { return cipher_time_window_; }
  319. void PSContext::set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold) {
  320. if (reconstruct_secrets_threshold == 0) {
  321. MS_LOG(EXCEPTION) << "reconstruct_secrets_threshold should be positive.";
  322. return;
  323. }
  324. reconstruct_secrets_threshold_ = reconstruct_secrets_threshold;
  325. }
  326. uint64_t PSContext::reconstruct_secrets_threshold() const { return reconstruct_secrets_threshold_; }
  327. void PSContext::set_fl_name(const std::string &fl_name) { fl_name_ = fl_name; }
  328. const std::string &PSContext::fl_name() const { return fl_name_; }
  329. void PSContext::set_fl_iteration_num(uint64_t fl_iteration_num) { fl_iteration_num_ = fl_iteration_num; }
  330. uint64_t PSContext::fl_iteration_num() const { return fl_iteration_num_; }
  331. void PSContext::set_client_epoch_num(uint64_t client_epoch_num) { client_epoch_num_ = client_epoch_num; }
  332. uint64_t PSContext::client_epoch_num() const { return client_epoch_num_; }
  333. void PSContext::set_client_batch_size(uint64_t client_batch_size) { client_batch_size_ = client_batch_size; }
  334. uint64_t PSContext::client_batch_size() const { return client_batch_size_; }
  335. void PSContext::set_client_learning_rate(float client_learning_rate) { client_learning_rate_ = client_learning_rate; }
  336. float PSContext::client_learning_rate() const { return client_learning_rate_; }
  337. void PSContext::set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration) {
  338. worker_step_num_per_iteration_ = worker_step_num_per_iteration;
  339. }
  340. uint64_t PSContext::worker_step_num_per_iteration() const { return worker_step_num_per_iteration_; }
  341. void PSContext::set_secure_aggregation(bool secure_aggregation) { secure_aggregation_ = secure_aggregation; }
  342. bool PSContext::secure_aggregation() const { return secure_aggregation_; }
  343. core::ClusterConfig &PSContext::cluster_config() {
  344. if (cluster_config_ == nullptr) {
  345. cluster_config_ = std::make_unique<core::ClusterConfig>(worker_num_, server_num_, scheduler_host_, scheduler_port_);
  346. MS_EXCEPTION_IF_NULL(cluster_config_);
  347. }
  348. return *cluster_config_;
  349. }
  350. void PSContext::set_root_first_ca_path(const std::string &root_first_ca_path) {
  351. root_first_ca_path_ = root_first_ca_path;
  352. }
  353. void PSContext::set_root_second_ca_path(const std::string &root_second_ca_path) {
  354. root_second_ca_path_ = root_second_ca_path;
  355. }
  356. std::string PSContext::root_first_ca_path() const { return root_first_ca_path_; }
  357. std::string PSContext::root_second_ca_path() const { return root_second_ca_path_; }
  358. void PSContext::set_pki_verify(bool pki_verify) { pki_verify_ = pki_verify; }
  359. bool PSContext::pki_verify() const { return pki_verify_; }
  360. void PSContext::set_replay_attack_time_diff(uint64_t replay_attack_time_diff) {
  361. replay_attack_time_diff_ = replay_attack_time_diff;
  362. }
  363. uint64_t PSContext::replay_attack_time_diff() const { return replay_attack_time_diff_; }
  364. std::string PSContext::equip_crl_path() const { return equip_crl_path_; }
  365. void PSContext::set_equip_crl_path(const std::string &equip_crl_path) { equip_crl_path_ = equip_crl_path; }
  366. void PSContext::set_scheduler_manage_port(uint16_t sched_port) { scheduler_manage_port_ = sched_port; }
  367. uint16_t PSContext::scheduler_manage_port() const { return scheduler_manage_port_; }
  368. void PSContext::set_config_file_path(const std::string &path) { config_file_path_ = path; }
  369. std::string PSContext::config_file_path() const { return config_file_path_; }
  370. void PSContext::set_node_id(const std::string &node_id) { node_id_ = node_id; }
  371. const std::string &PSContext::node_id() const { return node_id_; }
  372. bool PSContext::enable_ssl() const { return enable_ssl_; }
  373. void PSContext::set_enable_ssl(bool enabled) { enable_ssl_ = enabled; }
  374. std::string PSContext::client_password() const { return client_password_; }
  375. void PSContext::set_client_password(const std::string &password) { client_password_ = password; }
  376. std::string PSContext::server_password() const { return server_password_; }
  377. void PSContext::set_server_password(const std::string &password) { server_password_ = password; }
  378. std::string PSContext::http_url_prefix() const { return http_url_prefix_; }
  379. void PSContext::set_http_url_prefix(const std::string &http_url_prefix) { http_url_prefix_ = http_url_prefix; }
  380. } // namespace ps
  381. } // namespace mindspore