You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parameter_server.h 7.2 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
  17. #define MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
  18. #include <unistd.h>
  19. #include <unordered_map>
  20. #include <string>
  21. #include <iostream>
  22. #include <memory>
  23. #include <vector>
  24. #include <mutex>
  25. #include <condition_variable>
  26. #include <thread>
  27. #include <cmath>
  28. #include <random>
  29. #include <utility>
  30. #include <list>
  31. #include <map>
  32. #include <functional>
  33. #include <algorithm>
  34. #include "ir/func_graph.h"
  35. #include "backend/session/session_basic.h"
  36. #include "backend/session/anf_runtime_algorithm.h"
  37. #include "backend/session/session_factory.h"
  38. #include "ps/optimizer_info.h"
  39. #include "ps/optimizer_info_builder.h"
  40. #include "ps/ps_context.h"
  41. #include "runtime/device/cpu/kernel_select_cpu.h"
  42. #include "utils/ms_context.h"
  43. #include "backend/kernel_compiler/kernel.h"
  44. #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
  45. #include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
  46. #include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
  47. #include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
  48. #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
  49. #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
  50. #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
  51. #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
  52. #include "ps/random_normal/random_normal.h"
  53. #include "ps/constants.h"
  54. #include "ps/util.h"
  55. #include "ps/embedding_table_shard_metadata.h"
  56. #include "utils/log_adapter.h"
  57. #include "proto/comm.pb.h"
  58. #include "proto/ps.pb.h"
  59. #include "ps/core/server_node.h"
  60. #include "ps/core/node.h"
  61. namespace mindspore {
  62. namespace ps {
  63. class ParameterServer {
  64. public:
  65. static ParameterServer &GetInstance() {
  66. static ParameterServer instance;
  67. return instance;
  68. }
  69. void Run(const FuncGraphPtr &func_graph);
  70. private:
  71. ParameterServer()
  72. : pserver_num_(0),
  73. worker_num_(0),
  74. grad_accum_count_(0),
  75. handler_(nullptr),
  76. func_graph_(nullptr),
  77. sess_(nullptr),
  78. running_(true),
  79. thread_(nullptr),
  80. server_node_(nullptr) {}
  81. ~ParameterServer() = default;
  82. ParameterServer(const ParameterServer &) = delete;
  83. ParameterServer &operator=(const ParameterServer &) = delete;
  84. class ServerHandler {
  85. public:
  86. explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
  87. ~ServerHandler() = default;
  88. void Init();
  89. void operator()(const std::shared_ptr<core::TcpConnection> &conn, const std::shared_ptr<core::MessageMeta> &meta,
  90. const DataPtr &data, size_t size);
  91. void HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res);
  92. void HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res);
  93. void HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res);
  94. void HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res);
  95. void HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res);
  96. void HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res);
  97. void HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res);
  98. void HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res);
  99. void HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res);
  100. void HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res);
  101. void HandleFinalize(const DataPtr &data, size_t size, const VectorPtr &res);
  102. private:
  103. ParameterServer *ps_;
  104. typedef void (ServerHandler::*RequestHandler)(const DataPtr &data, size_t size, const VectorPtr &res);
  105. std::unordered_map<int, RequestHandler> handlers_;
  106. std::unordered_map<int, std::string> commands_;
  107. std::unordered_map<Key, bool> init_weights_;
  108. std::unordered_map<Key, bool> init_weight_to_optim_;
  109. std::unordered_map<Key, bool> init_optim_info_;
  110. };
  111. bool Init(const FuncGraphPtr &func_graph);
  112. void InitOptimInfoBuilders();
  113. void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
  114. void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
  115. void InitWeight(const Key &key, const WeightPtr &weight);
  116. void InitGrad(const Key &key, const GradPtr &grad);
  117. void InitEmbeddingTable(const Key &key,
  118. const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
  119. const ParamInitInfo &param_init_info);
  120. bool HasWeight(const Key &key);
  121. void Finalize();
  122. void UpdateWeights();
  123. void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
  124. WeightPtr weight(const Key &key);
  125. void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
  126. void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
  127. inline bool ReadyForUpdateWeights() const;
  128. inline bool ReadyForPush(const Key &key);
  129. inline bool ReadyForPull(const Key &key);
  130. inline void ResetGradAccumCount();
  131. const CNodePtr GetCNode(const std::string &name) const;
  132. inline std::mutex &mutex();
  133. void GetEmbeddingTableParamPtr();
  134. void SyncEmbeddingTables();
  135. size_t pserver_num_;
  136. size_t worker_num_;
  137. size_t grad_accum_count_;
  138. std::unique_ptr<ServerHandler> handler_;
  139. FuncGraphPtr func_graph_;
  140. std::shared_ptr<session::SessionBasic> sess_;
  141. bool running_;
  142. std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
  143. std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
  144. std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_;
  145. std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
  146. std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
  147. std::unordered_map<Key, std::string> weight_key_to_optims_;
  148. std::unordered_map<Key, std::string> weight_key_to_optim_op_;
  149. std::unordered_map<Key, WeightPtr> weights_;
  150. std::unordered_map<Key, bool> is_embedding_;
  151. std::unordered_map<Key, WeightPtr> grads_;
  152. std::unordered_map<Key, size_t> grads_accum_counter_;
  153. std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
  154. std::unordered_map<Key, uint64_t> tokens_;
  155. std::mutex mutex_;
  156. std::condition_variable apply_grads_cv_;
  157. std::unique_ptr<std::thread> thread_;
  158. std::shared_ptr<core::ServerNode> server_node_;
  159. std::map<Key, ParameterPtr> embedding_tables_;
  160. friend class ServerHandler;
  161. };
  162. } // namespace ps
  163. } // namespace mindspore
  164. #endif // MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_