You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

ps_context.h 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_CONTEXT_H_
  17. #define MINDSPORE_CCSRC_PS_CONTEXT_H_
  18. #include <map>
  19. #include <string>
  20. #include <memory>
  21. #include "ps/constants.h"
  22. #include "ps/core/cluster_metadata.h"
  23. #include "ps/core/cluster_config.h"
  24. namespace mindspore {
  25. namespace ps {
// Names of the supported server modes (parameter server, federated learning, hybrid training).
constexpr char kServerModePS[] = "PARAMETER_SERVER";
constexpr char kServerModeFL[] = "FEDERATED_LEARNING";
constexpr char kServerModeHybrid[] = "HYBRID_TRAINING";
// NOTE(review): presumably the name of the environment variable that carries the process role -- confirm against
// the implementation file.
constexpr char kEnvRole[] = "MS_ROLE";
// Role identifiers. kEnvRoleOfNotPS is the role a freshly constructed PSContext starts with.
constexpr char kEnvRoleOfPServer[] = "MS_PSERVER";
constexpr char kEnvRoleOfServer[] = "MS_SERVER";
constexpr char kEnvRoleOfWorker[] = "MS_WORKER";
constexpr char kEnvRoleOfScheduler[] = "MS_SCHED";
constexpr char kEnvRoleOfNotPS[] = "MS_NOT_PS";
// Encrypt type identifiers. kNotEncryptType is the encrypt type a freshly constructed PSContext starts with.
// NOTE(review): "DP" presumably stands for differential privacy and "PW" for pairwise -- confirm.
constexpr char kDPEncryptType[] = "DP_ENCRYPT";
constexpr char kPWEncryptType[] = "PW_ENCRYPT";
constexpr char kStablePWEncryptType[] = "STABLE_PW_ENCRYPT";
constexpr char kNotEncryptType[] = "NOT_ENCRYPT";
  39. // Use binary data to represent federated learning server's context so that we can judge which round resets the
  40. // iteration. From right to left, each bit stands for:
  41. // 0: Server is in parameter server mode.
  42. // 1: Server is in federated learning mode.
  43. // 2: Server is in mixed training mode.
  44. // 3: Server enables pairwise encrypt algorithm.
  45. // For example: 1010 stands for that the server is in federated learning mode and pairwise encrypt algorithm is enabled.
  46. enum class ResetterRound { kNoNeedToReset, kUpdateModel, kReconstructSeccrets, kPushWeight, kPushMetrics };
  47. const std::map<uint32_t, ResetterRound> kServerContextToResetRoundMap = {{0b0010, ResetterRound::kUpdateModel},
  48. {0b1010, ResetterRound::kReconstructSeccrets},
  49. {0b1100, ResetterRound::kPushMetrics},
  50. {0b0100, ResetterRound::kPushMetrics}};
  51. class PSContext {
  52. public:
  53. ~PSContext() = default;
  54. PSContext(PSContext const &) = delete;
  55. PSContext &operator=(const PSContext &) = delete;
  56. static std::shared_ptr<PSContext> instance();
  57. void SetPSEnable(bool enabled);
  58. bool is_ps_mode() const;
  59. void Reset();
  60. std::string ms_role() const;
  61. bool is_worker() const;
  62. bool is_server() const;
  63. bool is_scheduler() const;
  64. uint32_t initial_worker_num() const;
  65. uint32_t initial_server_num() const;
  66. std::string scheduler_host() const;
  67. void SetPSRankId(uint32_t rank_id);
  68. uint32_t ps_rank_id() const;
  69. void InsertHashTableSize(const std::string &param_name, size_t cache_vocab_size, size_t embedding_size,
  70. size_t vocab_size) const;
  71. void ReInsertHashTableSize(const std::string &new_param_name, const std::string &cur_param_name,
  72. size_t cache_vocab_size, size_t embedding_size) const;
  73. void InsertWeightInitInfo(const std::string &param_name, size_t global_seed, size_t op_seed) const;
  74. void InsertAccumuInitInfo(const std::string &param_name, float init_val) const;
  75. void CloneHashTable(const std::string &dest_param_name, const std::string &src_param_name) const;
  76. void set_cache_enable(bool cache_enable) const;
  77. void set_rank_id(uint32_t rank_id) const;
  78. bool enable_ssl() const;
  79. void set_enable_ssl(bool enabled);
  80. std::string client_password() const;
  81. void set_client_password(const std::string &password);
  82. std::string server_password() const;
  83. void set_server_password(const std::string &password);
  84. // In new server framework, process role, worker number, server number, scheduler ip and scheduler port should be set
  85. // by ps_context.
  86. void set_server_mode(const std::string &server_mode);
  87. const std::string &server_mode() const;
  88. void set_ms_role(const std::string &role);
  89. void set_worker_num(uint32_t worker_num);
  90. uint32_t worker_num() const;
  91. void set_server_num(uint32_t server_num);
  92. uint32_t server_num() const;
  93. void set_scheduler_ip(const std::string &sched_ip);
  94. std::string scheduler_ip() const;
  95. void set_scheduler_port(uint16_t sched_port);
  96. uint16_t scheduler_port() const;
  97. // Methods federated learning.
  98. // Generate which round should reset the iteration.
  99. void GenerateResetterRound();
  100. ResetterRound resetter_round() const;
  101. void set_fl_server_port(uint16_t fl_server_port);
  102. uint16_t fl_server_port() const;
  103. // Set true if this process is a federated learning worker in cross-silo scenario.
  104. void set_fl_client_enable(bool enabled);
  105. bool fl_client_enable() const;
  106. void set_start_fl_job_threshold(uint64_t start_fl_job_threshold);
  107. uint64_t start_fl_job_threshold() const;
  108. void set_start_fl_job_time_window(uint64_t start_fl_job_time_window);
  109. uint64_t start_fl_job_time_window() const;
  110. void set_update_model_ratio(float update_model_ratio);
  111. float update_model_ratio() const;
  112. void set_update_model_time_window(uint64_t update_model_time_window);
  113. uint64_t update_model_time_window() const;
  114. void set_share_secrets_ratio(float share_secrets_ratio);
  115. float share_secrets_ratio() const;
  116. void set_cipher_time_window(uint64_t cipher_time_window);
  117. uint64_t cipher_time_window() const;
  118. void set_reconstruct_secrets_threshold(uint64_t reconstruct_secrets_threshold);
  119. uint64_t reconstruct_secrets_threshold() const;
  120. void set_fl_name(const std::string &fl_name);
  121. const std::string &fl_name() const;
  122. // Set the iteration number of the federated learning.
  123. void set_fl_iteration_num(uint64_t fl_iteration_num);
  124. uint64_t fl_iteration_num() const;
  125. // Set the training epoch number of the client.
  126. void set_client_epoch_num(uint64_t client_epoch_num);
  127. uint64_t client_epoch_num() const;
  128. // Set the data batch size of the client.
  129. void set_client_batch_size(uint64_t client_batch_size);
  130. uint64_t client_batch_size() const;
  131. void set_client_learning_rate(float client_learning_rate);
  132. float client_learning_rate() const;
  133. void set_worker_step_num_per_iteration(uint64_t worker_step_num_per_iteration);
  134. uint64_t worker_step_num_per_iteration() const;
  135. core::ClusterConfig &cluster_config();
  136. void set_scheduler_manage_port(uint16_t sched_port);
  137. uint16_t scheduler_manage_port() const;
  138. void set_config_file_path(const std::string &path);
  139. std::string config_file_path() const;
  140. void set_dp_eps(float dp_eps);
  141. float dp_eps() const;
  142. void set_dp_delta(float dp_delta);
  143. float dp_delta() const;
  144. void set_dp_norm_clip(float dp_norm_clip);
  145. float dp_norm_clip() const;
  146. void set_encrypt_type(const std::string &encrypt_type);
  147. const std::string &encrypt_type() const;
  148. void set_node_id(const std::string &node_id);
  149. const std::string &node_id() const;
  150. private:
  151. PSContext()
  152. : ps_enabled_(false),
  153. is_worker_(false),
  154. is_pserver_(false),
  155. is_sched_(false),
  156. enable_ssl_(false),
  157. rank_id_(0),
  158. worker_num_(0),
  159. server_num_(0),
  160. scheduler_host_("0.0.0.0"),
  161. scheduler_port_(6667),
  162. role_(kEnvRoleOfNotPS),
  163. server_mode_(""),
  164. resetter_round_(ResetterRound::kNoNeedToReset),
  165. fl_server_port_(6668),
  166. fl_client_enable_(false),
  167. fl_name_(""),
  168. start_fl_job_threshold_(0),
  169. start_fl_job_time_window_(3000),
  170. update_model_ratio_(1.0),
  171. update_model_time_window_(3000),
  172. share_secrets_ratio_(1.0),
  173. cipher_time_window_(300000),
  174. reconstruct_secrets_threshold_(2000),
  175. fl_iteration_num_(20),
  176. client_epoch_num_(25),
  177. client_batch_size_(32),
  178. client_learning_rate_(0.001),
  179. worker_step_num_per_iteration_(65),
  180. secure_aggregation_(false),
  181. cluster_config_(nullptr),
  182. scheduler_manage_port_(11202),
  183. config_file_path_(""),
  184. dp_eps_(50),
  185. dp_delta_(0.01),
  186. dp_norm_clip_(1.0),
  187. encrypt_type_(kNotEncryptType),
  188. node_id_(""),
  189. client_password_(""),
  190. server_password_("") {}
  191. bool ps_enabled_;
  192. bool is_worker_;
  193. bool is_pserver_;
  194. bool is_sched_;
  195. bool enable_ssl_;
  196. uint32_t rank_id_;
  197. uint32_t worker_num_;
  198. uint32_t server_num_;
  199. std::string scheduler_host_;
  200. uint16_t scheduler_port_;
  201. // The server process's role.
  202. std::string role_;
  203. // Server mode which could be Parameter Server, Federated Learning and Hybrid Training mode.
  204. std::string server_mode_;
  205. // The round which will reset the iteration. Used in federated learning for now.
  206. ResetterRound resetter_round_;
  207. // Http port of federated learning server.
  208. uint16_t fl_server_port_;
  209. // Whether this process is the federated client. Used in cross-silo scenario of federated learning.
  210. bool fl_client_enable_;
  211. // Federated learning job name.
  212. std::string fl_name_;
  213. // The threshold count of startFLJob round. Used in federated learning for now.
  214. uint64_t start_fl_job_threshold_;
  215. // The time window of startFLJob round in millisecond.
  216. uint64_t start_fl_job_time_window_;
  217. // Update model threshold is a certain ratio of start_fl_job threshold which is set as update_model_ratio_.
  218. float update_model_ratio_;
  219. // The time window of updateModel round in millisecond.
  220. uint64_t update_model_time_window_;
  221. // Share model threshold is a certain ratio of share secrets threshold which is set as share_secrets_ratio_.
  222. float share_secrets_ratio_;
  223. // The time window of each cipher round in millisecond.
  224. uint64_t cipher_time_window_;
  225. // The threshold count of reconstruct secrets round. Used in federated learning for now.
  226. uint64_t reconstruct_secrets_threshold_;
  227. // Iteration number of federeated learning, which is the number of interactions between client and server.
  228. uint64_t fl_iteration_num_;
  229. // Client training epoch number. Used in federated learning for now.
  230. uint64_t client_epoch_num_;
  231. // Client training data batch size. Used in federated learning for now.
  232. uint64_t client_batch_size_;
  233. // Client training learning rate. Used in federated learning for now.
  234. float client_learning_rate_;
  235. // The worker standalone training step number before communicating with server.
  236. uint64_t worker_step_num_per_iteration_;
  237. // Whether to use secure aggregation algorithm. Used in federated learning for now.
  238. bool secure_aggregation_;
  239. // The cluster config read through environment variables, the value does not change.
  240. std::unique_ptr<core::ClusterConfig> cluster_config_;
  241. // The port used by scheduler to receive http requests for scale out or scale in.
  242. uint16_t scheduler_manage_port_;
  243. // The path of the configuration file, used to configure the certification path and persistent storage type, etc.
  244. std::string config_file_path_;
  245. // Epsilon budget of differential privacy mechanism. Used in federated learning for now.
  246. float dp_eps_;
  247. // Delta budget of differential privacy mechanism. Used in federated learning for now.
  248. float dp_delta_;
  249. // Norm clip factor of differential privacy mechanism. Used in federated learning for now.
  250. float dp_norm_clip_;
  251. // Secure mechanism for federated learning. Used in federated learning for now.
  252. std::string encrypt_type_;
  253. // Unique id of the node
  254. std::string node_id_;
  255. // Password used to decode p12 file.
  256. std::string client_password_;
  257. // Password used to decode p12 file.
  258. std::string server_password_;
  259. };
  260. } // namespace ps
  261. } // namespace mindspore
  262. #endif // MINDSPORE_CCSRC_PS_CONTEXT_H_