You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

parameter_server.h 10 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
  17. #define MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
  18. #include <unistd.h>
  19. #include <string>
  20. #include <iostream>
  21. #include <memory>
  22. #include <vector>
  23. #include <mutex>
  24. #include <condition_variable>
  25. #include <thread>
  26. #include <cmath>
  27. #include <random>
  28. #include <utility>
  29. #include <list>
  30. #include <map>
  31. #include <functional>
  32. #include <algorithm>
  33. #include "utils/hash_map.h"
  34. #include "ir/func_graph.h"
  35. #include "backend/common/session/session_basic.h"
  36. #include "backend/common/session/anf_runtime_algorithm.h"
  37. #include "include/common/utils/anfalgo.h"
  38. #include "backend/common/session/session_factory.h"
  39. #include "ps/optimizer_info.h"
  40. #include "ps/optimizer_info_builder.h"
  41. #include "ps/ps_context.h"
  42. #include "plugin/device/cpu/hal/device/kernel_select_cpu.h"
  43. #include "utils/ms_context.h"
  44. #include "kernel/kernel.h"
  45. #include "plugin/device/cpu/kernel/cpu_kernel_factory.h"
  46. #include "plugin/device/cpu/kernel/ps/pserver_kernel.h"
  47. #include "plugin/device/cpu/kernel/ps/sparse_apply_adam_ps_kernel.h"
  48. #include "plugin/device/cpu/kernel/ps/sparse_apply_lazy_adam_ps_kernel.h"
  49. #include "plugin/device/cpu/kernel/ps/sparse_apply_ftrl_ps_kernel.h"
  50. #include "plugin/device/cpu/kernel/ps/apply_momentum_ps_kernel.h"
  51. #include "plugin/device/cpu/kernel/ps/embedding_look_up_ps_kernel.h"
  52. #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
  53. #include "ps/random_normal/random_normal.h"
  54. #include "distributed/persistent/data.h"
  55. #include "ps/constants.h"
  56. #include "ps/util.h"
  57. #include "ps/embedding_table_shard_metadata.h"
  58. #include "utils/log_adapter.h"
  59. #include "proto/comm.pb.h"
  60. #include "proto/ps.pb.h"
  61. #include "ps/core/ps_server_node.h"
  62. #include "ps/core/node.h"
  63. namespace mindspore {
  64. namespace ps {
// Server-side core of MindSpore's parameter-server (PS) training mode.
// A single process-wide instance owns the weights, accumulated gradients and
// optimizer state hosted on this server (all keyed by Key) and answers worker
// requests -- push/pull, embedding lookup/update, initialization, finalize --
// through the nested ServerHandler. The nested RecoverHandler provides
// optional disaster recovery backed by file storage.
class ParameterServer {
 public:
  // Meyers singleton accessor; construction is thread-safe since C++11.
  static ParameterServer &GetInstance() {
    static ParameterServer instance;
    return instance;
  }

  // Entry point: sets up server state from the training graph and starts
  // serving worker requests (implementation lives in the .cc file).
  void Run(const FuncGraphPtr &func_graph);

 private:
  // Counters start at zero and every owned resource starts null; the real
  // setup is deferred to Init()/Run().
  ParameterServer()
      : pserver_num_(0),
        worker_num_(0),
        grad_accum_count_(0),
        handler_(nullptr),
        func_graph_(nullptr),
        sess_(nullptr),
        running_(true),
        thread_(nullptr),
        persist_thread_(nullptr),
        server_node_(nullptr) {}
  ~ParameterServer() = default;

  // Non-copyable: this type is a singleton.
  ParameterServer(const ParameterServer &) = delete;
  ParameterServer &operator=(const ParameterServer &) = delete;

  // Receives raw worker messages from the TCP server and dispatches each one
  // to the matching Handle* member function.
  class ServerHandler {
   public:
    explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
    ~ServerHandler() = default;

    // Populates the command dispatch tables (handlers_/commands_).
    // NOTE(review): inferred from the members below -- confirm in the .cc file.
    void Init();

    // Message entry point invoked per request by the server node.
    void operator()(const std::shared_ptr<core::TcpConnection> &conn, const std::shared_ptr<core::MessageMeta> &meta,
                    const void *data, size_t size);

    // One handler per request type; `data`/`size` is the raw request payload
    // and `res` receives the serialized response.
    void HandlePushReq(const void *data, size_t size, const VectorPtr &res);
    void HandlePullReq(const void *data, size_t size, const VectorPtr &res);
    void HandleInitWeights(const void *data, size_t size, const VectorPtr &res);
    void HandleInitWeightToOptimId(const void *data, size_t size, const VectorPtr &res);
    void HandleInitInputsShape(const void *data, size_t size, const VectorPtr &res);
    void HandleInitEmbeddings(const void *data, size_t size, const VectorPtr &res);
    void HandleCheckReadyForPush(const void *data, size_t size, const VectorPtr &res);
    void HandleCheckReadyForPull(const void *data, size_t size, const VectorPtr &res);
    void HandleEmbeddingLookup(const void *data, size_t size, const VectorPtr &res);
    void HandleUpdateEmbeddings(const void *data, size_t size, const VectorPtr &res);
    void HandleFinalize(const void *data, size_t size, const VectorPtr &res);

   private:
    // Back pointer to the owning server; not owned.
    ParameterServer *ps_;
    typedef void (ServerHandler::*RequestHandler)(const void *data, size_t size, const VectorPtr &res);
    // Dispatch table: command id -> member-function handler.
    mindspore::HashMap<int, RequestHandler> handlers_;
    // Command id -> human-readable command name.
    mindspore::HashMap<int, std::string> commands_;
    // Per-key flags tracking which initialization steps have been handled.
    mindspore::HashMap<Key, bool> init_weights_;
    mindspore::HashMap<Key, bool> init_weight_to_optim_;
    mindspore::HashMap<Key, bool> init_optim_info_;
  };

  // For disaster recovery, you can customize the key-value structure that needs to be persisted, and you can customize
  // the business layer disaster recovery function.
  class RecoverHandler {
   public:
    explicit RecoverHandler(ParameterServer *ps) : ps_(ps) {}
    ~RecoverHandler() = default;

    // Initialize storage module and file storage is currently used.
    void Init();

    // Do disaster recovery.
    void Recover();

    // Raw accessor to the underlying file storage; ownership stays here.
    core::FileConfiguration *config_storage() const { return storage_.get(); }

   private:
    // Load embedding information from persistent storage to recover embedding table.
    void RecoverEmbedding();

    // Back pointer to the owning server; not owned.
    ParameterServer *ps_;
    typedef void (RecoverHandler::*RecoverFunc)();
    // Recovery dispatch table: persisted item name -> member-function restorer.
    mindspore::HashMap<std::string, RecoverFunc> handlers_;
    // File-backed persistent storage used during recovery.
    std::unique_ptr<core::FileConfiguration> storage_{nullptr};
  };

  // ---- Initialization (driven by worker init messages) ----
  bool Init(const FuncGraphPtr &func_graph);
  void InitOptimInfoBuilders();
  void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
  void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
  void InitWeight(const Key &key, const WeightPtr &weight);
  void InitGrad(const Key &key, const GradPtr &grad);
  void InitEmbeddingTable(const Key &key,
                          const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
                          const ParamInitInfo &param_init_info);
  bool HasWeight(const Key &key);
  void Finalize();

  // ---- Serving-time operations ----
  void UpdateWeights();
  void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
  WeightPtr weight(const Key &key);
  void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
  void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
  inline bool ReadyForUpdateWeights() const;
  inline bool ReadyForPush(const Key &key);
  inline bool ReadyForPull(const Key &key);
  inline void ResetGradAccumCount();
  const CNodePtr GetCNode(const std::string &name) const;
  inline std::mutex &mutex();
  void GetEmbeddingTableParamPtr();
  void SyncEmbeddingTables();

  // Cache embedding table parameter by map, key: parameter name, value: parameter node pointer
  void CacheEmbeddingTableParamPtr();

  // ---- Disaster recovery ----
  // Whether enable disaster recovery.
  bool EnableRecovery() const;

  // Persist weight periodically, trigger by scheduler.
  void PersistParameters();

  // Persist sparse network operators when receive init embedding table message.
  void PersistKernels(const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
                      const ParamInitInfo &param_init_info) const;

  // Persist parameters store in parameter server when receive init message.
  void PersistInitParameters(const Key &key, const WeightPtr &param);

  // Restore sparse network operators and parameters.
  void RecoverEmbedding(const std::vector<Key> &keys, const std::vector<std::vector<std::vector<size_t>>> &shapes_list,
                        const std::vector<std::string> &param_names);

  // Restore sparse network operators.
  void RecoverKernels(const std::vector<Key> &keys, const std::vector<std::vector<std::vector<size_t>>> &shapes_list,
                      const std::vector<std::string> &param_names);

  // Restore parameters store in parameter server.
  void RecoverParameters(const std::vector<Key> &keys);

  // Update the indices of modified part of the persistent parameter.
  void UpdateDirtyInfo(const Key &key, const LookupIds &lookup_ids, int64_t offset);

  // Set current persistent state to server node.
  void set_persistent_state(core::PersistentState persistent_state) const;

  // ---- State ----
  std::unique_ptr<RecoverHandler> recover_handler_;
  std::atomic_bool finish_recovery_{false};
  size_t pserver_num_;
  size_t worker_num_;
  // Count of gradient pushes accumulated since the last weight update;
  // reset via ResetGradAccumCount().
  size_t grad_accum_count_;
  std::unique_ptr<ServerHandler> handler_;
  FuncGraphPtr func_graph_;
  std::shared_ptr<session::SessionBasic> sess_;
  bool running_;
  bool embedding_param_ptr_cached_{false};

  // Used to cache embedding table parameter, key: parameter name, value: parameter node pointer
  mindspore::HashMap<std::string, ParameterPtr> embedding_parameter_tables_;

  // Used to cache the modified part of the parameter.
  mindspore::HashMap<Key, distributed::storage::DirtyInfo> weights_dirty_info_;

  // Optimizer kernels, input shapes and bookkeeping, keyed by parameter Key
  // (or by optimizer name for the builders).
  mindspore::HashMap<Key, std::shared_ptr<PServerKernel>> optimizers_;
  mindspore::HashMap<Key, InputsShapePtr> optim_inputs_shape_;
  mindspore::HashMap<Key, InputsShapePtr> original_optim_inputs_shape_;
  mindspore::HashMap<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
  mindspore::HashMap<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
  mindspore::HashMap<Key, std::string> weight_key_to_optims_;
  mindspore::HashMap<Key, std::string> weight_key_to_optim_op_;

  // Hosted weights, gradients and per-key counters.
  mindspore::HashMap<Key, WeightPtr> weights_;
  mindspore::HashMap<Key, bool> is_embedding_;
  mindspore::HashMap<Key, GradPtr> grads_;
  mindspore::HashMap<Key, size_t> grads_accum_counter_;
  mindspore::HashMap<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
  mindspore::HashMap<Key, uint64_t> tokens_;

  // Synchronization and background threads. apply_grads_cv_ pairs with
  // mutex_; access_weight_mutex_ guards weight access separately.
  std::mutex mutex_;
  std::condition_variable apply_grads_cv_;
  std::mutex access_weight_mutex_;
  std::unique_ptr<std::thread> thread_;
  std::unique_ptr<std::thread> persist_thread_;
  std::shared_ptr<core::ServerNode> server_node_;
  std::map<Key, ParameterPtr> embedding_tables_;

  // The handler invokes private server internals directly.
  friend class ServerHandler;
};
  216. } // namespace ps
  217. } // namespace mindspore
  218. #endif // MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_