You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

parameter_server.h 35 kB

5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
  17. #define MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_
#include <unistd.h>
#include <cmath>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <random>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
  33. #include "ir/func_graph.h"
  34. #include "backend/session/session_basic.h"
  35. #include "backend/session/anf_runtime_algorithm.h"
  36. #include "backend/session/session_factory.h"
  37. #include "ps/common.h"
  38. #include "ps/optimizer_info.h"
  39. #include "ps/optimizer_info_builder.h"
  40. #include "ps/util.h"
  41. #include "ps/ps_context.h"
  42. #include "runtime/device/cpu/kernel_select_cpu.h"
  43. #include "utils/ms_context.h"
  44. #include "backend/kernel_compiler/kernel.h"
  45. #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
  46. #include "backend/kernel_compiler/cpu/ps/pserver_kernel.h"
  47. #include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h"
  48. #include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h"
  49. #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h"
  50. #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h"
  51. #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h"
  52. #include "ps/ps_cache/ps_data/ps_data_prefetch.h"
  53. #include "ps/random_normal/random_normal.h"
  54. namespace mindspore {
  55. namespace ps {
  56. using mindspore::kernel::ps::PServerKernel;
  57. using AnfAlgo = session::AnfRuntimeAlgorithm;
// Server side of MindSpore parameter-server training.
// Owns the master copies of weights and embedding tables, accumulates
// gradients pushed by workers, applies optimizer kernels on a background
// thread, and answers pull / embedding-lookup requests. T is the value type
// carried over the ps-lite transport. Process-wide singleton.
template <typename T>
class ParameterServer {
 public:
  // Meyers singleton: constructed on first use, never copied.
  static ParameterServer &GetInstance() {
    static ParameterServer instance;
    return instance;
  }

  // Entry point: sets up and runs the server for the given training graph.
  void Run(const FuncGraphPtr &func_graph);

 private:
  ParameterServer()
      : pserver_num_(0),
        worker_num_(0),
        rank_id_(0),
        grad_accum_count_(0),
        ps_(new ::ps::KVServer<T>(0)),
        handler_(nullptr),
        func_graph_(nullptr),
        sess_(nullptr),
        running_(true),
        thread_(nullptr) {}
  ~ParameterServer() = default;
  ParameterServer(const ParameterServer &) = delete;
  ParameterServer &operator=(const ParameterServer &) = delete;

  // Dispatches incoming ps-lite KV requests to the matching Handle* method.
  // An instance of this class is installed as the ::ps::KVServer request
  // handle (it is copied into the server, see Init()).
  class ServerHandler {
   public:
    explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
    ~ServerHandler() = default;
    // Fills the command-id -> handler table; must run before serving.
    void Init();
    // Request entry point invoked by ::ps::KVServer for every message.
    void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVServer<T> *server);

   private:
    void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandleInitWeights(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
                                   ::ps::KVPairs<T> *res);
    void HandleInitInputsShape(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandleInitEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandleCheckReadyForPush(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandleCheckReadyForPull(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    void HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res);
    ParameterServer *ps_;  // back-pointer to the owning server; not owned
    typedef void (ServerHandler::*RequestHandler)(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
                                                  ::ps::KVPairs<T> *res);
    std::unordered_map<int64_t, RequestHandler> handlers_;  // command id -> member handler
    std::unordered_map<Key, bool> init_weights_;            // keys whose weight init was seen
    std::unordered_map<Key, bool> init_weight_to_optim_;    // keys already bound to an optimizer id
    std::unordered_map<Key, bool> init_optim_info_;         // keys whose optimizer input shapes were received
  };

  // One-time setup: topology query, handler installation, update thread.
  bool Init(const FuncGraphPtr &func_graph);
  // Registers one OptimizerInfoBuilder per supported optimizer name.
  void InitOptimInfoBuilders();
  // Records which optimizer (by id) drives the weight at `key`.
  void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
  // Records optimizer input shapes and instantiates the optimizer kernel.
  void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
  // Stores the master copy of a (non-embedding) weight.
  void InitWeight(const Key &key, const WeightPtr &weight);
  // Allocates the gradient accumulation buffer for a weight.
  void InitGrad(const Key &key, const GradPtr &grad);
  // Creates and randomly/explicitly initializes one embedding table.
  void InitEmbeddingTable(const Key &key,
                          const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
                          const ParamInitInfo &param_init_info);
  bool HasWeight(const Key &key);
  // Stops the background UpdateWeights() thread.
  void Finalize();
  // Background loop: applies optimizers once all workers pushed gradients.
  void UpdateWeights();
  // Accumulates one worker's pushed gradients for `key`.
  void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
  // Returns a copy of the weight and consumes one pull token.
  WeightPtr weight(const Key &key);
  void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res);
  void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);
  bool ReadyForUpdateWeights();
  bool ReadyForPush(const Key &key);
  bool ReadyForPull(const Key &key);
  void ResetGradAccumCount();
  const CNodePtr GetCNode(const std::string &name) const;
  std::mutex &mutex();
  void GetEmbeddingTableParamPtr();
  void SyncEmbeddingTables();

  size_t pserver_num_;       // number of parameter servers in the job
  size_t worker_num_;        // number of workers pushing gradients
  size_t rank_id_;           // this server's rank
  size_t grad_accum_count_;  // weights whose gradients are complete this round
  std::unique_ptr<::ps::KVServer<T>> ps_;
  std::unique_ptr<ServerHandler> handler_;
  FuncGraphPtr func_graph_;
  std::shared_ptr<session::SessionBasic> sess_;
  bool running_;  // cleared by Finalize(); guarded by mutex_ for the CV wait
  std::unordered_map<Key, std::shared_ptr<PServerKernel>> optimizers_;
  std::unordered_map<Key, InputsShapePtr> optim_inputs_shape_;
  std::unordered_map<Key, InputsShapePtr> original_optim_inputs_shape_;
  std::unordered_map<Key, std::shared_ptr<OptimizerInfo>> optim_infos_;
  std::unordered_map<std::string, std::shared_ptr<OptimizerInfoBuilder>> optim_info_builders_;
  std::unordered_map<Key, std::string> weight_key_to_optims_;     // key -> optimizer name
  std::unordered_map<Key, std::string> weight_key_to_optim_op_;   // key -> optimizer op node name
  std::unordered_map<Key, WeightPtr> weights_;                    // master weights & embedding tables
  std::unordered_map<Key, bool> is_embedding_;                    // key -> is an embedding table
  std::unordered_map<Key, WeightPtr> grads_;                      // accumulated gradients
  std::unordered_map<Key, size_t> grads_accum_counter_;           // workers that pushed this round
  std::unordered_map<Key, std::shared_ptr<PServerKernel>> embedding_lookup_ops_;
  std::unordered_map<Key, uint64_t> tokens_;  // remaining pulls allowed per key
  std::mutex mutex_;                          // guards all server state above
  std::condition_variable apply_grads_cv_;    // wakes UpdateWeights()
  std::unique_ptr<std::thread> thread_;       // runs UpdateWeights()
  std::map<Key, ParameterPtr> embedding_tables_;

  friend class ServerHandler;
};
  160. class FuncGraph;
  161. template <typename T>
  162. void ParameterServer<T>::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
  163. ::ps::KVServer<T> *server) {
  164. MS_EXCEPTION_IF_NULL(server);
  165. ::ps::KVPairs<T> res;
  166. if (handlers_.count(req_meta.cmd) > 0) {
  167. auto &handler_ptr = handlers_[req_meta.cmd];
  168. (this->*handler_ptr)(req_meta, req_data, &res);
  169. } else if (req_meta.push) {
  170. HandlePushReq(req_meta, req_data, &res);
  171. } else {
  172. HandlePullReq(req_meta, req_data, &res);
  173. }
  174. server->Response(req_meta, res);
  175. }
  176. template <typename T>
  177. void ParameterServer<T>::ServerHandler::Init() {
  178. handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights;
  179. handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
  180. handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
  181. handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
  182. handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
  183. handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
  184. handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
  185. handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings;
  186. handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
  187. }
// Gradient push from a worker: accumulate into the server-side buffers.
// The reply stays empty.
template <typename T>
void ParameterServer<T>::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
                                                      ::ps::KVPairs<T> *res) {
  MS_EXCEPTION_IF_NULL(res);
  ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens);
}
  194. template <typename T>
  195. void ParameterServer<T>::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
  196. ::ps::KVPairs<T> *res) {
  197. MS_EXCEPTION_IF_NULL(res);
  198. res->keys = req_data.keys;
  199. ::ps::Key key = req_data.keys[0];
  200. res->vals = *(ps_->weight(key));
  201. }
  202. template <typename T>
  203. void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &req_meta,
  204. const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
  205. std::unique_lock<std::mutex> lock(ps_->mutex());
  206. MS_EXCEPTION_IF_NULL(res);
  207. size_t key_num = req_data.keys.size();
  208. T *data_ptr = req_data.vals.data();
  209. size_t pos = 0;
  210. for (size_t i = 0; i < key_num; i++) {
  211. Key key = req_data.keys[i];
  212. size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
  213. if (!ps_->HasWeight(key)) {
  214. WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
  215. MS_EXCEPTION_IF_NULL(weight_ptr);
  216. weight_ptr->CopyFrom(data_ptr + pos, data_len);
  217. ps_->InitWeight(key, weight_ptr);
  218. GradPtr grad_ptr = std::make_shared<::ps::SArray<T>>(data_len, 0);
  219. MS_EXCEPTION_IF_NULL(grad_ptr);
  220. ps_->InitGrad(key, grad_ptr);
  221. }
  222. pos += data_len;
  223. }
  224. }
  225. template <typename T>
  226. void ParameterServer<T>::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVMeta &req_meta,
  227. const ::ps::KVPairs<T> &req_data,
  228. ::ps::KVPairs<T> *res) {
  229. std::unique_lock<std::mutex> lock(ps_->mutex());
  230. MS_EXCEPTION_IF_NULL(res);
  231. size_t key_num = req_data.keys.size();
  232. for (size_t i = 0; i < key_num; i++) {
  233. Key key = req_data.keys[i];
  234. T val = req_data.vals[i];
  235. if (init_weight_to_optim_[key]) {
  236. continue;
  237. } else {
  238. init_weight_to_optim_[key] = true;
  239. }
  240. ps_->InitWeightKeyToOptims(key, val);
  241. }
  242. }
  243. template <typename T>
  244. void ParameterServer<T>::ServerHandler::HandleInitInputsShape(const ::ps::KVMeta &req_meta,
  245. const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
  246. std::unique_lock<std::mutex> lock(ps_->mutex());
  247. MS_EXCEPTION_IF_NULL(res);
  248. const Key &key = req_data.keys[0];
  249. if (init_optim_info_[key]) {
  250. return;
  251. } else {
  252. init_optim_info_[key] = true;
  253. }
  254. ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens);
  255. }
  256. template <typename T>
  257. void ParameterServer<T>::ServerHandler::HandleInitEmbeddings(const ::ps::KVMeta &req_meta,
  258. const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
  259. std::unique_lock<std::mutex> lock(ps_->mutex());
  260. MS_EXCEPTION_IF_NULL(res);
  261. const Key &key = req_data.keys[0];
  262. MS_LOG(INFO) << "Initializing embedding table for key:" << key;
  263. std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
  264. std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
  265. MS_EXCEPTION_IF_NULL(shapes);
  266. std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>();
  267. MS_EXCEPTION_IF_NULL(input_shape);
  268. std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>();
  269. MS_EXCEPTION_IF_NULL(indices_shape);
  270. std::shared_ptr<std::vector<size_t>> output_shape = std::make_shared<std::vector<size_t>>();
  271. MS_EXCEPTION_IF_NULL(output_shape);
  272. shapes->push_back(input_shape);
  273. shapes->push_back(indices_shape);
  274. shapes->push_back(output_shape);
  275. const Lengths &lens = req_data.lens;
  276. size_t index = 0;
  277. for (int64_t i = 0; i < lens[0]; i++) {
  278. input_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
  279. }
  280. for (int64_t j = 0; j < lens[1]; j++) {
  281. indices_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
  282. }
  283. for (int64_t k = 0; k < lens[2]; k++) {
  284. output_shape->push_back(static_cast<size_t>(req_data.vals[index++]));
  285. }
  286. ParamInitInfo param_init_info;
  287. if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
  288. param_init_info.param_type_ = static_cast<ParamType>(lens[3]);
  289. if (param_init_info.param_type_ == kWeight) {
  290. param_init_info.global_seed_ = static_cast<size_t>(lens[4]);
  291. param_init_info.op_seed_ = static_cast<size_t>(lens[5]);
  292. } else if (param_init_info.param_type_ == kAccumulation) {
  293. param_init_info.init_val_ = req_data.vals[index];
  294. }
  295. }
  296. ps_->InitEmbeddingTable(key, shapes, param_init_info);
  297. }
  298. template <typename T>
  299. void ParameterServer<T>::ServerHandler::HandleCheckReadyForPush(const ::ps::KVMeta &req_meta,
  300. const ::ps::KVPairs<T> &req_data,
  301. ::ps::KVPairs<T> *res) {
  302. MS_EXCEPTION_IF_NULL(res);
  303. const Key &key = req_data.keys[0];
  304. bool ready = ps_->ReadyForPush(key);
  305. res->keys.push_back(key);
  306. res->vals.push_back(ready);
  307. }
  308. template <typename T>
  309. void ParameterServer<T>::ServerHandler::HandleCheckReadyForPull(const ::ps::KVMeta &req_meta,
  310. const ::ps::KVPairs<T> &req_data,
  311. ::ps::KVPairs<T> *res) {
  312. MS_EXCEPTION_IF_NULL(res);
  313. const Key &key = req_data.keys[0];
  314. bool ready = ps_->ReadyForPull(key);
  315. res->keys.push_back(key);
  316. res->vals.push_back(ready);
  317. }
  318. template <typename T>
  319. void ParameterServer<T>::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta,
  320. const ::ps::KVPairs<T> &req_data, ::ps::KVPairs<T> *res) {
  321. MS_EXCEPTION_IF_NULL(res);
  322. const Key &key = req_data.keys[0];
  323. for (size_t i = 1; i < req_data.keys.size(); i++) {
  324. res->keys.push_back(req_data.keys[i]);
  325. }
  326. ps_->DoEmbeddingLookup(key, req_data.keys.segment(1, req_data.keys.size()), res);
  327. }
  328. template <typename T>
  329. void ParameterServer<T>::ServerHandler::HandleUpdateEmbeddings(const ::ps::KVMeta &req_meta,
  330. const ::ps::KVPairs<T> &req_data,
  331. ::ps::KVPairs<T> *res) {
  332. std::unique_lock<std::mutex> lock(ps_->mutex());
  333. MS_EXCEPTION_IF_NULL(res);
  334. const Key &key = req_data.keys[0];
  335. const LookupIds &lookup_ids = req_data.keys.segment(1, req_data.keys.size());
  336. const Values &update_vals = req_data.vals;
  337. ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
  338. }
// Shutdown request: stop the server's background update thread.
template <typename T>
void ParameterServer<T>::ServerHandler::HandleFinalize(const ::ps::KVMeta &req_meta, const ::ps::KVPairs<T> &req_data,
                                                       ::ps::KVPairs<T> *res) {
  MS_EXCEPTION_IF_NULL(res);
  ps_->Finalize();
}
// One-time setup: query the ps-lite topology, install the request handler,
// and spawn the background weight-update thread. Always returns true.
template <typename T>
bool ParameterServer<T>::Init(const FuncGraphPtr &func_graph) {
  pserver_num_ = ::ps::NumServers();
  worker_num_ = ::ps::NumWorkers();
  func_graph_ = func_graph;
  rank_id_ = ::ps::MyRank();
  handler_.reset(new ServerHandler(this));
  handler_->Init();
  InitOptimInfoBuilders();
  // The handler is passed by value into the ps-lite server, so it must be
  // fully initialized (Init() called) before this point.
  ps_->set_request_handle(*handler_);
  thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
  GetEmbeddingTableParamPtr();
  return true;
}
  359. template <typename T>
  360. void ParameterServer<T>::InitOptimInfoBuilders() {
  361. std::shared_ptr<OptimizerInfoBuilder> momentum_info_builder = std::make_shared<MomentumOptimInfoBuilder>(worker_num_);
  362. std::shared_ptr<OptimizerInfoBuilder> sparse_adam_info_builder =
  363. std::make_shared<SparseAdamOptimInfoBuilder>(worker_num_);
  364. std::shared_ptr<OptimizerInfoBuilder> sparse_ftrl_info_builder =
  365. std::make_shared<SparseFtrlOptimInfoBuilder>(worker_num_);
  366. optim_info_builders_[kApplyMomentum] = momentum_info_builder;
  367. optim_info_builders_[kSparseAdam] = sparse_adam_info_builder;
  368. optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder;
  369. }
  370. template <typename T>
  371. void ParameterServer<T>::InitWeightKeyToOptims(const Key &key, const int64_t &optim_id) {
  372. if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") {
  373. return;
  374. }
  375. weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
  376. weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
  377. MS_LOG(INFO) << "Initializing optimizer id for key:" << key << ", optimizer name:" << weight_key_to_optims_[key]
  378. << ", optimizer op name:" << weight_key_to_optim_op_[key];
  379. }
  380. template <typename T>
  381. void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) {
  382. InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
  383. MS_EXCEPTION_IF_NULL(inputs_shape);
  384. InputsShapePtr original_inputs_shape = std::make_shared<InputsShape>();
  385. MS_EXCEPTION_IF_NULL(original_inputs_shape);
  386. int64_t val_idx = 0;
  387. const Key &key = keys[0];
  388. MS_LOG(INFO) << "Initializing optimizer inputs shape for key:" << key;
  389. if (optim_inputs_shape_.count(key) == 0) {
  390. original_optim_inputs_shape_[key] = original_inputs_shape;
  391. optim_inputs_shape_[key] = inputs_shape;
  392. }
  393. for (size_t i = 0; i < keys.size(); i++) {
  394. auto shape = std::make_shared<std::vector<size_t>>();
  395. MS_EXCEPTION_IF_NULL(shape);
  396. auto original_shape = std::make_shared<std::vector<size_t>>();
  397. MS_EXCEPTION_IF_NULL(original_shape);
  398. inputs_shape->push_back(shape);
  399. original_inputs_shape->push_back(original_shape);
  400. for (int64_t j = 0; j < lengths[i]; j++) {
  401. shape->push_back(values[val_idx]);
  402. original_shape->push_back(values[val_idx++]);
  403. }
  404. }
  405. if (weight_key_to_optims_.count(key) > 0) {
  406. const std::string &optim_name = weight_key_to_optims_[key];
  407. const std::string &optim_op_name = weight_key_to_optim_op_[key];
  408. if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
  409. const CNodePtr cnode = GetCNode(optim_op_name);
  410. MS_EXCEPTION_IF_NULL(cnode);
  411. if (optim_name == kSparseAdam) {
  412. std::shared_ptr<PServerKernel> optimizer =
  413. std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
  414. optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
  415. optimizers_[key] = optimizer;
  416. } else if (optim_name == kSparseLazyAdam) {
  417. std::shared_ptr<PServerKernel> optimizer =
  418. std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
  419. optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
  420. optimizers_[key] = optimizer;
  421. } else if (optim_name == kApplyMomentum) {
  422. std::shared_ptr<PServerKernel> optimizer =
  423. std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_, worker_num_);
  424. optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
  425. optimizers_[key] = optimizer;
  426. } else if (optim_name == kSparseFtrl) {
  427. std::shared_ptr<PServerKernel> optimizer =
  428. std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_, worker_num_);
  429. optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
  430. optimizers_[key] = optimizer;
  431. }
  432. }
  433. }
  434. }
  435. template <typename T>
  436. const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
  437. std::list<CNodePtr> cnodes = func_graph_->GetOrderedCnodes();
  438. for (CNodePtr cnode : cnodes) {
  439. MS_EXCEPTION_IF_NULL(cnode);
  440. std::string fullname = cnode->fullname_with_scope();
  441. if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) {
  442. return cnode;
  443. }
  444. }
  445. return nullptr;
  446. }
  447. template <typename T>
  448. void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
  449. MS_EXCEPTION_IF_NULL(weight);
  450. if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
  451. MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
  452. weights_[key] = weight;
  453. tokens_[key] = 0;
  454. is_embedding_[key] = false;
  455. }
  456. }
  457. template <typename T>
  458. void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
  459. MS_EXCEPTION_IF_NULL(grad);
  460. if (grads_.count(key) == 0) {
  461. grads_[key] = grad;
  462. grads_accum_counter_[key] = 0;
  463. }
  464. }
// Create one embedding table for `key`: instantiate its lookup kernel,
// allocate the flat table buffer, and fill it either from the seeds / init
// value in `param_init_info` (PS-cache mode) or from a local N(0, 0.01)
// draw. No-op when the key already has a weight.
template <typename T>
void ParameterServer<T>::InitEmbeddingTable(
  const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
  const ParamInitInfo &param_init_info) {
  MS_EXCEPTION_IF_NULL(shapes);
  if (weights_.count(key) == 0) {
    std::shared_ptr<PServerKernel> lookup =
      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_, worker_num_);
    lookup->InitKernel(shapes);
    embedding_lookup_ops_[key] = lookup;
    // Init embedding weight: total size is the product of the kernel's
    // input dimensions.
    const std::vector<size_t> &input_shapes = lookup->input_sizes();
    size_t total_dims =
      std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies<size_t>());
    WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
    MS_EXCEPTION_IF_NULL(embedding);
    T *embedding_data = embedding->data();
    // Fallback initializer for the non-cache path; default-seeded, so every
    // run (and every server) produces the same sequence.
    std::default_random_engine engine;
    std::normal_distribution<float> random(0, 0.01);
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      if (param_init_info.param_type_ == kWeight) {
        // Seeded init keeps weight values reproducible across servers.
        InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data);
      } else if (param_init_info.param_type_ == kAccumulation) {
        // Optimizer accumulators start from a constant.
        for (size_t i = 0; i < total_dims; i++) {
          embedding_data[i] = param_init_info.init_val_;
        }
      }
    } else {
      for (size_t i = 0; i < total_dims; i++) {
        embedding_data[i] = random(engine);
      }
    }
    weights_[key] = embedding;
    tokens_[key] = 0;
    is_embedding_[key] = true;
    grads_accum_counter_[key] = 0;
  }
}
// True only when `key` has a weight but no is_embedding_ entry.
// NOTE(review): this tests is_embedding_.count(), not the stored flag, yet
// both InitWeight and InitEmbeddingTable create an is_embedding_ entry for
// every key they touch — so for initialized keys this returns false and
// HandleInitWeights will re-run the (idempotent) init path. Confirm whether
// a value check (!is_embedding_[key]) was intended.
template <typename T>
bool ParameterServer<T>::HasWeight(const Key &key) {
  return (weights_.count(key) > 0 && !is_embedding_.count(key));
}
  507. template <typename T>
  508. void ParameterServer<T>::Finalize() {
  509. running_ = false;
  510. apply_grads_cv_.notify_one();
  511. }
  512. template <typename T>
  513. void ParameterServer<T>::UpdateWeights() {
  514. while (true) {
  515. std::unique_lock<std::mutex> lock(mutex_);
  516. apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
  517. if (!running_) {
  518. break;
  519. }
  520. for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
  521. Key key = iter->first;
  522. WeightPtr weight_ptr = iter->second;
  523. std::shared_ptr<PServerKernel> optimizer = nullptr;
  524. if (weight_key_to_optims_.count(key) > 0) {
  525. optimizer = optimizers_[key];
  526. }
  527. MS_EXCEPTION_IF_NULL(optimizer);
  528. std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
  529. if (optim_info != nullptr) {
  530. const std::vector<kernel::AddressPtr> &inputs = optim_info->inputs();
  531. const std::vector<kernel::AddressPtr> &workspaces = optim_info->workspaces();
  532. const std::vector<kernel::AddressPtr> &outputs = optim_info->outputs();
  533. std::vector<std::vector<size_t>> shapes = {};
  534. std::vector<size_t> indices_shape = {};
  535. indices_shape.emplace_back(optim_info->indice_size());
  536. shapes.push_back(indices_shape);
  537. if (original_optim_inputs_shape_.count(key) != 0) {
  538. for (auto input_shapes : *(original_optim_inputs_shape_[key])) {
  539. shapes.push_back(*input_shapes);
  540. }
  541. }
  542. optimizer->ReInit(shapes);
  543. optim_info->ComputeMean(shapes, worker_num_, pserver_num_, rank_id_);
  544. optimizer->Execute(inputs, workspaces, outputs);
  545. optim_info->Reset();
  546. }
  547. if (!is_embedding_[key]) {
  548. tokens_[key] = worker_num_;
  549. }
  550. }
  551. ResetGradAccumCount();
  552. }
  553. }
// Accumulate one worker's pushed gradients for keys[0]. Creates the
// OptimizerInfo on first push; subsequent pushes merge into it. When the
// push completes the round (all workers, all weights), wakes UpdateWeights().
// A push of the single sentinel value -100 means "no sparse gradient this
// step" and only advances the counters.
template <typename T>
void ParameterServer<T>::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
  std::unique_lock<std::mutex> lock(mutex_);
  const Key &key = keys[0];
  bool no_sparse_grad = values.size() == 1 && values[0] == -100;
  if (!no_sparse_grad) {
    std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
    // Create or update the optimizer info
    if (optim_info == nullptr) {
      const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
      std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
      if (pserver_kernel == nullptr) {
        MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
      }
      // (Redundant with the explicit check above; kept as written.)
      MS_EXCEPTION_IF_NULL(pserver_kernel);
      // Builder returns a raw owning pointer; adopt it into the shared_ptr.
      OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths,
                                            optim_inputs_shape_[key], worker_num_, is_embedding_[key]);
      optim_info.reset(optim);
      optim_infos_[key] = optim_info;
    } else {
      optim_info->Update(values, lengths);
      optim_info->Accumulate(values, lengths);
    }
  }
  grads_accum_counter_[key] += 1;
  if (grads_accum_counter_[key] == worker_num_) {
    // This key's round is complete; count it toward the global round.
    grad_accum_count_++;
  }
  if (ReadyForUpdateWeights()) {
    apply_grads_cv_.notify_one();
  }
}
  586. template <typename T>
  587. WeightPtr ParameterServer<T>::weight(const Key &key) {
  588. std::unique_lock<std::mutex> lock(mutex_);
  589. if (weights_.count(key) == 0) {
  590. MS_LOG(EXCEPTION) << "Invalid weight key " << key;
  591. }
  592. WeightPtr weight_ptr = weights_[key];
  593. MS_EXCEPTION_IF_NULL(weight_ptr);
  594. WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray<T>>(weight_ptr->size(), 0);
  595. MS_EXCEPTION_IF_NULL(copy_weight_ptr);
  596. copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size());
  597. tokens_[key] -= 1;
  598. return copy_weight_ptr;
  599. }
// Run the embedding-lookup kernel for `key` over `lookup_ids` and write the
// gathered rows (and their length) into `res`. Unknown table / kernel keys
// are logged and ignored.
template <typename T>
void ParameterServer<T>::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs<T> *res) {
  std::unique_lock<std::mutex> lock(mutex_);
  MS_EXCEPTION_IF_NULL(res);
  if (weights_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding table key " << key;
    return;
  }
  if (embedding_lookup_ops_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
    return;
  }
  WeightPtr table_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(table_ptr);
  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
  MS_EXCEPTION_IF_NULL(table_lookup_op);
  // Update shapes of lookup operator: only the indices count varies per
  // request.
  std::vector<std::vector<size_t>> shapes = {};
  std::vector<size_t> indices_shape = {};
  indices_shape.emplace_back(lookup_ids.size());
  shapes.push_back(indices_shape);
  table_lookup_op->ReInit(shapes);
  const std::vector<size_t> output_shapes = table_lookup_op->output_sizes();
  // Kernel inputs: [0] the flat table buffer, [1] the int indices.
  std::vector<kernel::AddressPtr> inputs;
  AddressPtr embedding_table = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(embedding_table);
  AddressPtr indices = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(indices);
  inputs.push_back(embedding_table);
  inputs.push_back(indices);
  embedding_table->addr = table_ptr->data();
  embedding_table->size = table_ptr->size() * sizeof(T);
  // The kernel wants int indices; narrow the (wider) lookup ids into a
  // scratch buffer that lives until Execute() returns.
  std::unique_ptr<int[]> tmp_ids(new int[lookup_ids.size()]);
  MS_EXCEPTION_IF_NULL(tmp_ids);
  for (size_t i = 0; i < lookup_ids.size(); i++) {
    tmp_ids[i] = static_cast<int>(lookup_ids[i]);
  }
  indices->addr = tmp_ids.get();
  indices->size = lookup_ids.size() * sizeof(int);
  std::vector<kernel::AddressPtr> workspaces;
  std::vector<kernel::AddressPtr> outputs;
  AddressPtr output = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(output);
  // Output buffer sized from the kernel's byte count; `addr` keeps it alive
  // and is copied into the reply below.
  std::shared_ptr<Values> addr = std::make_shared<Values>(output_shapes[0] / sizeof(T), 0);
  MS_EXCEPTION_IF_NULL(addr);
  output->addr = addr->data();
  output->size = output_shapes[0];
  outputs.push_back(output);
  table_lookup_op->Execute(inputs, workspaces, outputs);
  res->vals = *addr;
  res->lens.push_back(res->vals.size());
}
template <typename T>
void ParameterServer<T>::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) {
  // Writes `vals` into the rows of embedding table `key` selected by
  // `lookup_ids`, delegating the row-wise update to the per-table lookup
  // kernel. Unknown keys are reported with ERROR and the call is a no-op.
  // NOTE(review): unlike weight()/DoEmbeddingLookup(), this method does not
  // acquire mutex_ — confirm that callers serialize access to weights_ before
  // invoking it, otherwise it can race with concurrent lookups.
  if (weights_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding table key " << key;
    return;
  }
  if (embedding_lookup_ops_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
    return;
  }
  WeightPtr table_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(table_ptr);
  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
  MS_EXCEPTION_IF_NULL(table_lookup_op);
  table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
}
  668. template <typename T>
  669. inline bool ParameterServer<T>::ReadyForUpdateWeights() {
  670. return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
  671. }
  672. template <typename T>
  673. inline bool ParameterServer<T>::ReadyForPush(const Key &key) {
  674. std::unique_lock<std::mutex> lock(mutex_);
  675. if (weights_.empty()) {
  676. MS_LOG(EXCEPTION) << "The weights in server is empty. Many reasons could cause this: 1.The Worker didn't send "
  677. "kInitWeightsCmd command. 2.The Server failed to initialize weights.";
  678. }
  679. return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
  680. }
  681. template <typename T>
  682. inline bool ParameterServer<T>::ReadyForPull(const Key &key) {
  683. std::unique_lock<std::mutex> lock(mutex_);
  684. if (tokens_.count(key) == 0 || weights_[key] == 0) {
  685. MS_LOG(EXCEPTION) << "Invalid weight key " << key;
  686. }
  687. return tokens_[key] > 0;
  688. }
  689. template <typename T>
  690. inline void ParameterServer<T>::ResetGradAccumCount() {
  691. grad_accum_count_ = 0;
  692. for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
  693. grads_accum_counter_[iter->first] = 0;
  694. }
  695. }
template <typename T>
inline std::mutex &ParameterServer<T>::mutex() {
  // Exposes the server-wide lock so callers can serialize their own work with
  // the weight/gradient accesses guarded inside this class.
  return mutex_;
}
template <typename T>
void ParameterServer<T>::GetEmbeddingTableParamPtr() {
  // Walks the graph's CNodes in order and records, under consecutive keys
  // starting at 0, every Parameter that feeds input 0 of an EmbeddingLookup /
  // GatherV2 / SparseGatherV2 node. The key order therefore depends on the
  // graph's node ordering and must match the worker side's key assignment.
  MS_EXCEPTION_IF_NULL(func_graph_);
  auto cnodes = func_graph_->GetOrderedCnodes();
  Key count = 0;
  for (auto cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
      auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
      // A Load node may wrap the real parameter; look through it to the
      // underlying input.
      if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) {
        auto embedding_cnode = embedding_table->cast<CNodePtr>();
        embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0);
      }
      MS_EXCEPTION_IF_NULL(embedding_table);
      if (embedding_table->isa<Parameter>()) {
        MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
        embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
        count++;
      }
    }
  }
}
  723. template <typename T>
  724. void ParameterServer<T>::SyncEmbeddingTables() {
  725. for (auto embedding_table : embedding_tables_) {
  726. Key key = embedding_table.first;
  727. if (embedding_lookup_ops_.count(key) == 0) {
  728. MS_LOG(WARNING) << "Can't find look up PS kernel for key " << key;
  729. continue;
  730. }
  731. auto lookup = embedding_lookup_ops_[key];
  732. const std::vector<size_t> &input_shapes = lookup->input_sizes();
  733. std::vector<int64_t> new_tensor_shape(input_shapes.begin(), input_shapes.end());
  734. tensor::TensorPtr new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, new_tensor_shape);
  735. MS_EXCEPTION_IF_NULL(new_tensor);
  736. float *new_tensor_data_ptr = reinterpret_cast<float *>(new_tensor->data_c());
  737. size_t new_tensor_size = static_cast<size_t>(new_tensor->data().nbytes());
  738. size_t embedding_table_size = weights_[key]->size() * sizeof(float);
  739. if (new_tensor_size != embedding_table_size) {
  740. MS_LOG(EXCEPTION) << "Shape of embedding table can't match. New tensor size:" << new_tensor_size
  741. << ", embedding_table size:" << embedding_table_size;
  742. }
  743. MS_EXCEPTION_IF_NULL(new_tensor_data_ptr);
  744. MS_EXCEPTION_IF_NULL(weights_[key]->data());
  745. int64_t ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size);
  746. if (ret != 0) {
  747. MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
  748. return;
  749. }
  750. auto paramter_tensor_ptr = embedding_table.second->default_param();
  751. MS_EXCEPTION_IF_NULL(paramter_tensor_ptr);
  752. paramter_tensor_ptr->cast<tensor::TensorPtr>()->AssignValue(*new_tensor);
  753. }
  754. }
  755. template <typename T>
  756. void ParameterServer<T>::Run(const FuncGraphPtr &func_graph) {
  757. MS_EXCEPTION_IF_NULL(func_graph);
  758. MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
  759. ::ps::Start(0);
  760. MS_LOG(INFO) << "PServer connected successfully.";
  761. if (!::ps::IsServer()) {
  762. std::cout << "This is not ther Server" << std::endl;
  763. return;
  764. }
  765. Init(func_graph);
  766. PSContext::instance()->SetPSRankId(rank_id_);
  767. thread_->join();
  768. SyncEmbeddingTables();
  769. MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
  770. ::ps::Finalize(0, true);
  771. MS_LOG(INFO) << "PServer finalized successfully.";
  772. }
  773. } // namespace ps
  774. } // namespace mindspore
  775. #endif // MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_