
parameter_server.cc 30 kB

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/parameter_server.h"

namespace mindspore {
namespace ps {
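// Entry point of the parameter server. Connects to the scheduler and workers,
// then blocks on the weight-update thread until Finalize() is triggered, and
// finally writes the trained embedding tables back into the graph parameters.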
void ParameterServer::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
  server_node_ = std::make_shared<core::ServerNode>();
  MS_LOG(INFO) << "PServer connected successfully.";
  if (!PSContext::instance()->is_server()) {
    MS_LOG(INFO) << "This is not the Server node.";
    return;
  }
  Init(func_graph);
  server_node_->Start();
  PSContext::instance()->SetPSRankId(server_node_->rank_id());
  thread_->join();
  SyncEmbeddingTables();
  MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
  server_node_->Finish();
  server_node_->Stop();
  MS_LOG(INFO) << "PServer finalized successfully.";
}

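// Reads the cluster topology from the environment, registers the message
// handlers and optimizer-info builders, installs a timeout callback that
// shuts the server down, and spawns the background weight-update thread.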
bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
  pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, 10);
  worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, 10);
  func_graph_ = func_graph;
  handler_.reset(new ServerHandler(this));
  handler_->Init();
  InitOptimInfoBuilders();
  server_node_->set_handler(*handler_);
  server_node_->set_event_callback([&](const core::NodeEvent &event) {
    if ((event == core::NodeEvent::CLUSTER_TIMEOUT) || (event == core::NodeEvent::SCHEDULER_TIMEOUT) ||
        (event == core::NodeEvent::NODE_TIMEOUT)) {
      MS_LOG(ERROR) << "Trigger timeout event:" << event << " begin to exit the system!";
      Finalize();
    }
  });
  thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
  GetEmbeddingTableParamPtr();
  return true;
}

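// One OptimizerInfoBuilder per supported optimizer; a builder assembles the
// kernel inputs (weight, gradients, indices) from the raw worker payloads.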
void ParameterServer::InitOptimInfoBuilders() {
  std::shared_ptr<OptimizerInfoBuilder> momentum_info_builder = std::make_shared<MomentumOptimInfoBuilder>(worker_num_);
  std::shared_ptr<OptimizerInfoBuilder> sparse_adam_info_builder =
    std::make_shared<SparseAdamOptimInfoBuilder>(worker_num_);
  std::shared_ptr<OptimizerInfoBuilder> sparse_ftrl_info_builder =
    std::make_shared<SparseFtrlOptimInfoBuilder>(worker_num_);
  optim_info_builders_[kApplyMomentum] = momentum_info_builder;
  optim_info_builders_[kSparseAdam] = sparse_adam_info_builder;
  optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder;
}

void ParameterServer::InitWeightKeyToOptims(const Key &key, const int64_t &optim_id) {
  if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") {
    return;
  }
  weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
  weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
  MS_LOG(INFO) << "Initializing optimizer id for key:" << key << ", optimizer name:" << weight_key_to_optims_[key]
               << ", optimizer op name:" << weight_key_to_optim_op_[key];
}

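// Records the input shapes the optimizer kernel expects for this weight key
// and lazily instantiates the matching PS kernel the first time the key is
// seen. `values` is a flattened list of dimensions; `lengths[i]` says how
// many of them belong to input i.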
void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) {
  InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
  MS_EXCEPTION_IF_NULL(inputs_shape);
  InputsShapePtr original_inputs_shape = std::make_shared<InputsShape>();
  MS_EXCEPTION_IF_NULL(original_inputs_shape);
  int64_t val_idx = 0;
  const Key &key = keys[0];
  MS_LOG(INFO) << "Initializing optimizer inputs shape for key:" << key;
  if (optim_inputs_shape_.count(key) == 0) {
    original_optim_inputs_shape_[key] = original_inputs_shape;
    optim_inputs_shape_[key] = inputs_shape;
  }
  for (size_t i = 0; i < keys.size(); i++) {
    auto shape = std::make_shared<std::vector<size_t>>();
    MS_EXCEPTION_IF_NULL(shape);
    auto original_shape = std::make_shared<std::vector<size_t>>();
    MS_EXCEPTION_IF_NULL(original_shape);
    inputs_shape->push_back(shape);
    original_inputs_shape->push_back(original_shape);
    for (int64_t j = 0; j < lengths[i]; j++) {
      shape->push_back(values[val_idx]);
      original_shape->push_back(values[val_idx++]);
    }
  }
  if (weight_key_to_optims_.count(key) > 0) {
    const std::string &optim_name = weight_key_to_optims_[key];
    const std::string &optim_op_name = weight_key_to_optim_op_[key];
    if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
      const CNodePtr cnode = GetCNode(optim_op_name);
      MS_EXCEPTION_IF_NULL(cnode);
      if (optim_name == kSparseAdam) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kSparseLazyAdam) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kApplyMomentum) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::ApplyMomentumPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kSparseFtrl) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      }
    }
  }
}

void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) {
  MS_EXCEPTION_IF_NULL(weight);
  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
    MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << server_node_->rank_id();
    weights_[key] = weight;
    tokens_[key] = 0;
    is_embedding_[key] = false;
  }
}

void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) {
  MS_EXCEPTION_IF_NULL(grad);
  if (grads_.count(key) == 0) {
    grads_[key] = grad;
    grads_accum_counter_[key] = 0;
  }
}

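// Builds the EmbeddingLookUp kernel for this table and allocates this
// server's shard of the embedding weight. With the data-prefetch cache
// enabled the shard is initialized from the recorded seeds (weights) or a
// constant (accumulators); otherwise it is filled from N(0, 0.01).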
void ParameterServer::InitEmbeddingTable(
  const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
  const ParamInitInfo &param_init_info) {
  MS_EXCEPTION_IF_NULL(shapes);
  if (weights_.count(key) == 0) {
    std::shared_ptr<PServerKernel> lookup =
      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
    lookup->InitKernel(shapes);
    embedding_lookup_ops_[key] = lookup;

    // Init embedding weight
    const std::vector<size_t> &input_shapes = lookup->input_sizes();
    size_t total_dims =
      std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies<size_t>());
    WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
    MS_EXCEPTION_IF_NULL(embedding);
    float *embedding_data = embedding->data();
    std::default_random_engine engine;
    std::normal_distribution<float> random(0, 0.01);
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      if (param_init_info.param_type_ == kWeight) {
        InitRandomNormal(0, 0.01, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_, embedding_data);
      } else if (param_init_info.param_type_ == kAccumulation) {
        for (size_t i = 0; i < total_dims; i++) {
          embedding_data[i] = param_init_info.init_val_;
        }
      }
    } else {
      for (size_t i = 0; i < total_dims; i++) {
        embedding_data[i] = random(engine);
      }
    }
    weights_[key] = embedding;
    MS_LOG(DEBUG) << "The key:" << key << " the embedding:" << *embedding;
    tokens_[key] = 0;
    is_embedding_[key] = true;
    grads_accum_counter_[key] = 0;
  }
}

bool ParameterServer::HasWeight(const Key &key) { return (weights_.count(key) > 0 && !is_embedding_.count(key)); }

void ParameterServer::Finalize() {
  running_ = false;
  apply_grads_cv_.notify_one();
}

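// Body of the background update thread: wait until every weight has a full
// round of gradients (or the server is shutting down), run each optimizer
// kernel once, then hand workers fresh pull tokens for the non-embedding
// weights and reset the accumulation counters.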
void ParameterServer::UpdateWeights() {
  while (true) {
    MS_LOG(INFO) << "The running is:" << running_ << " the ready is:" << this->ReadyForUpdateWeights();
    std::unique_lock<std::mutex> lock(mutex_);
    apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
    if (!running_) {
      break;
    }
    for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
      Key key = iter->first;
      WeightPtr weight_ptr = iter->second;
      std::shared_ptr<PServerKernel> optimizer = nullptr;
      if (weight_key_to_optims_.count(key) > 0) {
        optimizer = optimizers_[key];
      }
      MS_EXCEPTION_IF_NULL(optimizer);
      std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
      if (optim_info != nullptr) {
        const std::vector<kernel::AddressPtr> &inputs = optim_info->inputs();
        const std::vector<kernel::AddressPtr> &workspaces = optim_info->workspaces();
        const std::vector<kernel::AddressPtr> &outputs = optim_info->outputs();
        std::vector<std::vector<size_t>> shapes = {};
        std::vector<size_t> indices_shape = {};
        indices_shape.emplace_back(optim_info->indice_size());
        shapes.push_back(indices_shape);
        if (original_optim_inputs_shape_.count(key) != 0) {
          std::transform(
            (*(original_optim_inputs_shape_[key])).begin(), (*(original_optim_inputs_shape_[key])).end(),
            std::back_inserter(shapes),
            [](std::shared_ptr<std::vector<size_t>> input_shapes) -> std::vector<size_t> { return *input_shapes; });
        }
        optimizer->ReInit(shapes);
        optim_info->ComputeMean(shapes, worker_num_, pserver_num_, server_node_->rank_id());
        optimizer->Execute(inputs, workspaces, outputs);
        optim_info->Reset();
      }
      if (!is_embedding_[key]) {
        tokens_[key] = worker_num_;
      }
    }
    ResetGradAccumCount();
  }
}

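// Accumulates one worker's gradient push. A single value of -100 is the
// sentinel for "no sparse gradient this step". Once all workers have pushed
// for every key, the update thread is woken up.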
void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
  std::unique_lock<std::mutex> lock(mutex_);
  const Key &key = keys[0];
  bool no_sparse_grad = values.size() == 1 && values[0] == -100;
  if (!no_sparse_grad) {
    std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
    // Create or update the optimizer info
    if (optim_info == nullptr) {
      const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
      std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
      if (pserver_kernel == nullptr) {
        MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key];
      }
      MS_EXCEPTION_IF_NULL(pserver_kernel);
      OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths,
                                            optim_inputs_shape_[key], worker_num_, is_embedding_[key]);
      optim_info.reset(optim);
      optim_infos_[key] = optim_info;
    } else {
      optim_info->Update(values, lengths);
      optim_info->Accumulate(values, lengths);
    }
  }
  grads_accum_counter_[key] += 1;
  if (grads_accum_counter_[key] == worker_num_) {
    grad_accum_count_++;
  }
  if (ReadyForUpdateWeights()) {
    apply_grads_cv_.notify_one();
  }
}

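// Returns a snapshot of the weight for a pull request and consumes one pull
// token for this key.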
WeightPtr ParameterServer::weight(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (weights_.count(key) == 0) {
    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
  }
  WeightPtr weight_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(weight_ptr);
  // Deep-copy the weight so the caller can serialize it outside the lock.
  WeightPtr copy_weight_ptr = std::make_shared<std::vector<float>>(weight_ptr->begin(), weight_ptr->end());
  MS_EXCEPTION_IF_NULL(copy_weight_ptr);
  tokens_[key] -= 1;
  return copy_weight_ptr;
}

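// Runs the embedding-lookup kernel for one table: reshapes the op to the
// incoming id count, wires the table and the ids in as kernel addresses, and
// writes the gathered rows into the response message.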
void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res) {
  std::unique_lock<std::mutex> lock(mutex_);
  MS_EXCEPTION_IF_NULL(res);
  if (weights_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding table key " << key;
    return;
  }
  if (embedding_lookup_ops_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
    return;
  }
  WeightPtr table_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(table_ptr);
  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
  MS_EXCEPTION_IF_NULL(table_lookup_op);

  // Update shapes of lookup operator
  std::vector<std::vector<size_t>> shapes = {};
  std::vector<size_t> indices_shape = {};
  indices_shape.emplace_back(lookup_ids.size());
  shapes.push_back(indices_shape);
  table_lookup_op->ReInit(shapes);

  const std::vector<size_t> output_shapes = table_lookup_op->output_sizes();
  std::vector<kernel::AddressPtr> inputs;
  AddressPtr embedding_table = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(embedding_table);
  AddressPtr indices = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(indices);
  inputs.push_back(embedding_table);
  inputs.push_back(indices);
  embedding_table->addr = table_ptr->data();
  embedding_table->size = table_ptr->size() * sizeof(float);

  std::unique_ptr<int[]> tmp_ids(new int[lookup_ids.size()]);
  MS_EXCEPTION_IF_NULL(tmp_ids);
  for (size_t i = 0; i < lookup_ids.size(); i++) {
    tmp_ids[i] = static_cast<int>(lookup_ids[i]);
  }
  indices->addr = tmp_ids.get();
  indices->size = lookup_ids.size() * sizeof(int);

  std::vector<kernel::AddressPtr> workspaces;
  std::vector<kernel::AddressPtr> outputs;
  AddressPtr output = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(output);
  std::shared_ptr<Values> addr = std::make_shared<Values>(output_shapes[0] / sizeof(float), 0);
  MS_EXCEPTION_IF_NULL(addr);
  output->addr = addr->data();
  output->size = output_shapes[0];
  outputs.push_back(output);

  table_lookup_op->Execute(inputs, workspaces, outputs);
  *res->mutable_values() = {addr->begin(), addr->end()};
  res->add_len(res->values_size());
}

void ParameterServer::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) {
  if (weights_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding table key " << key;
    return;
  }
  if (embedding_lookup_ops_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
    return;
  }
  WeightPtr table_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(table_ptr);
  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
  MS_EXCEPTION_IF_NULL(table_lookup_op);
  table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
}

inline bool ParameterServer::ReadyForUpdateWeights() {
  return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
}

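// Token accounting: after each update every non-embedding weight gets
// worker_num_ pull tokens, and each pull consumes one. A worker may push
// again only once the tokens are used up, i.e. after everyone has pulled the
// new weights.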
inline bool ParameterServer::ReadyForPush(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (weights_.empty()) {
    MS_LOG(EXCEPTION) << "The weights in the server are empty. Many reasons could cause this: 1.The worker didn't "
                         "send the kInitWeightsCmd command. 2.The server failed to initialize weights.";
  }
  MS_LOG(INFO) << "The grad_accum_count_:" << grad_accum_count_ << " the weights_:" << weights_.size()
               << " the token:" << (tokens_[key] <= 0);
  return grad_accum_count_ < weights_.size() && tokens_[key] <= 0;
}

inline bool ParameterServer::ReadyForPull(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (tokens_.count(key) == 0 || weights_[key] == nullptr) {
    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
  }
  MS_LOG(INFO) << "ReadyForPull: " << (tokens_[key] > 0);
  return tokens_[key] > 0;
}

inline void ParameterServer::ResetGradAccumCount() {
  grad_accum_count_ = 0;
  for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
    grads_accum_counter_[iter->first] = 0;
  }
}

const CNodePtr ParameterServer::GetCNode(const std::string &name) const {
  std::list<CNodePtr> cnodes = func_graph_->GetOrderedCnodes();
  for (CNodePtr cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string fullname = cnode->fullname_with_scope();
    if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) {
      return cnode;
    }
  }
  return nullptr;
}

inline std::mutex &ParameterServer::mutex() { return mutex_; }

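// Walks the graph and records the Parameter behind every embedding-lookup /
// gather node; the traversal order assigns each table its key.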
void ParameterServer::GetEmbeddingTableParamPtr() {
  MS_EXCEPTION_IF_NULL(func_graph_);
  auto cnodes = func_graph_->GetOrderedCnodes();
  Key count = 0;
  for (auto cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
      auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
      if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) {
        auto embedding_cnode = embedding_table->cast<CNodePtr>();
        embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0);
      }
      MS_EXCEPTION_IF_NULL(embedding_table);
      if (embedding_table->isa<Parameter>()) {
        MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
        embedding_tables_.insert(std::make_pair(count, embedding_table->cast<ParameterPtr>()));
        count++;
      }
    }
  }
}

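// Called once training is done: copies each embedding shard back into the
// default parameter tensor of the corresponding graph Parameter so the
// trained table can be exported with the model.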
void ParameterServer::SyncEmbeddingTables() {
  for (auto embedding_table : embedding_tables_) {
    Key key = embedding_table.first;
    if (embedding_lookup_ops_.count(key) == 0) {
      MS_LOG(WARNING) << "Can't find lookup PS kernel for key " << key;
      continue;
    }
    auto lookup = embedding_lookup_ops_[key];
    const std::vector<size_t> &input_shapes = lookup->input_sizes();
    std::vector<int64_t> new_tensor_shape(input_shapes.begin(), input_shapes.end());
    tensor::TensorPtr new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, new_tensor_shape);
    MS_EXCEPTION_IF_NULL(new_tensor);
    float *new_tensor_data_ptr = reinterpret_cast<float *>(new_tensor->data_c());
    size_t new_tensor_size = static_cast<size_t>(new_tensor->data().nbytes());
    size_t embedding_table_size = weights_[key]->size() * sizeof(float);
    if (new_tensor_size != embedding_table_size) {
      MS_LOG(EXCEPTION) << "Shape of embedding table can't match. New tensor size:" << new_tensor_size
                        << ", embedding_table size:" << embedding_table_size;
    }
    MS_EXCEPTION_IF_NULL(new_tensor_data_ptr);
    MS_EXCEPTION_IF_NULL(weights_[key]->data());
    int64_t ret = memcpy_s(new_tensor_data_ptr, new_tensor_size, weights_[key]->data(), embedding_table_size);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "memcpy_s error, error code (" << ret << ")";
      return;
    }
    auto parameter_tensor_ptr = embedding_table.second->default_param();
    MS_EXCEPTION_IF_NULL(parameter_tensor_ptr);
    parameter_tensor_ptr->cast<tensor::TensorPtr>()->AssignValue(*new_tensor);
  }
}

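// Maps each command id to its handler and to a human-readable name used in
// the dispatch logs below.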
void ParameterServer::ServerHandler::Init() {
  handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights;
  handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
  handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
  handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
  handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
  handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
  handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
  handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings;
  handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
  handlers_[kPushCmd] = &ServerHandler::HandlePushReq;
  handlers_[kPullCmd] = &ServerHandler::HandlePullReq;
  commands_[kInitWeightsCmd] = "kInitWeightsCmd";
  commands_[kInitWeightToOptimIdCmd] = "kInitWeightToOptimIdCmd";
  commands_[kInitOptimInputsShapeCmd] = "kInitOptimInputsShapeCmd";
  commands_[kInitEmbeddingsCmd] = "kInitEmbeddingsCmd";
  commands_[kCheckReadyForPushCmd] = "kCheckReadyForPushCmd";
  commands_[kCheckReadyForPullCmd] = "kCheckReadyForPullCmd";
  commands_[kEmbeddingLookupCmd] = "kEmbeddingLookupCmd";
  commands_[kUpdateEmbeddingsCmd] = "kUpdateEmbeddingsCmd";
  commands_[kFinalizeCmd] = "kFinalizeCmd";
  commands_[kPushCmd] = "kPushCmd";
  commands_[kPullCmd] = "kPullCmd";
}

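// Dispatch point for every incoming TCP message: looks up the handler for the
// command, runs it, and always sends a response (an empty one if the handler
// produced no payload) because the sender waits synchronously.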
void ParameterServer::ServerHandler::operator()(std::shared_ptr<core::TcpConnection> conn,
                                                std::shared_ptr<core::MessageMeta> meta, DataPtr data, size_t size) {
  auto output = std::make_shared<std::vector<unsigned char>>();
  if (commands_.count(meta->user_cmd()) == 0) {
    MS_LOG(EXCEPTION) << "The command:" << meta->user_cmd() << " is not supported!";
  }
  MS_LOG(INFO) << "The command is:" << commands_[meta->user_cmd()];
  auto &handler_ptr = handlers_[meta->user_cmd()];
  (this->*handler_ptr)(data, size, output);
  MS_LOG(DEBUG) << "The output size is:" << output->size();
  if (output->size() > 0) {
    ps_->server_node_->Response(conn, meta, output->data(), output->size());
  } else {
    // If the output is empty, respond with an empty string. Because Response is
    // synchronous, res can safely be released once the call returns.
    std::string res;
    ps_->server_node_->Response(conn, meta, res.data(), res.length());
  }
  MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:"
                << std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now())
                     .time_since_epoch()
                     .count();
}

void ParameterServer::ServerHandler::HandlePushReq(DataPtr data, size_t size, VectorPtr res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  input.ParseFromArray(data.get(), size);
  Keys keys = {input.keys().begin(), input.keys().end()};
  Values values = {input.values().begin(), input.values().end()};
  Lengths lens = {input.len().begin(), input.len().end()};
  MS_LOG(DEBUG) << "The keys:" << keys << " the values:" << values << " the len:" << lens;
  ps_->AccumGrad(keys, values, lens);
}

void ParameterServer::ServerHandler::HandlePullReq(DataPtr data, size_t size, VectorPtr res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  input.ParseFromArray(data.get(), size);
  KVMessage res_data;
  *res_data.mutable_keys() = input.keys();
  Key key = input.keys()[0];
  auto weight = ps_->weight(key);
  *res_data.mutable_values() = {weight->begin(), weight->end()};
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s error, error code (" << ret << ")";
  }
}

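// Initializes weights pushed by a worker. When no per-key length list is
// sent, the values are assumed to be split evenly across the keys.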
void ParameterServer::ServerHandler::HandleInitWeights(DataPtr data, size_t size, VectorPtr res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  input.ParseFromArray(data.get(), size);
  int key_num = input.keys_size();
  const float *data_ptr = input.values().data();
  size_t pos = 0;
  for (int i = 0; i < key_num; i++) {
    Key key = input.keys()[i];
    size_t data_len = input.len_size() != key_num ? input.values_size() / key_num : input.len()[i];
    if (!ps_->HasWeight(key)) {
      WeightPtr weight_ptr = std::make_shared<std::vector<float>>(data_ptr + pos, data_ptr + (pos + data_len));
      MS_EXCEPTION_IF_NULL(weight_ptr);
      ps_->InitWeight(key, weight_ptr);
      GradPtr grad_ptr = std::make_shared<std::vector<float>>(data_len, 0);
      MS_EXCEPTION_IF_NULL(grad_ptr);
      ps_->InitGrad(key, grad_ptr);
    }
    pos += data_len;
  }
}

void ParameterServer::ServerHandler::HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  input.ParseFromArray(data.get(), size);
  size_t key_num = input.keys_size();
  for (size_t i = 0; i < key_num; i++) {
    Key key = input.keys()[i];
    float val = input.values()[i];
    if (init_weight_to_optim_[key]) {
      continue;
    } else {
      init_weight_to_optim_[key] = true;
    }
    ps_->InitWeightKeyToOptims(key, val);
  }
}

void ParameterServer::ServerHandler::HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  input.ParseFromArray(data.get(), size);
  const Key &key = input.keys()[0];
  if (init_optim_info_[key]) {
    return;
  } else {
    init_optim_info_[key] = true;
  }
  Keys keys = {input.keys().begin(), input.keys().end()};
  Values values = {input.values().begin(), input.values().end()};
  Lengths lens = {input.len().begin(), input.len().end()};
  ps_->InitOptimInputsShape(keys, values, lens);
}

void ParameterServer::ServerHandler::HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  EmbeddingTableMeta embedding_table_meta;
  embedding_table_meta.ParseFromArray(data.get(), size);
  const Key &key = embedding_table_meta.key();
  MS_LOG(INFO) << "Initializing embedding table for key:" << key;
  std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
    std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
  MS_EXCEPTION_IF_NULL(shapes);
  std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.input_shape().begin(), embedding_table_meta.input_shape().end());
  MS_EXCEPTION_IF_NULL(input_shape);
  std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.indices_shape().begin(), embedding_table_meta.indices_shape().end());
  MS_EXCEPTION_IF_NULL(indices_shape);
  std::shared_ptr<std::vector<size_t>> output_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.output_shape().begin(), embedding_table_meta.output_shape().end());
  MS_EXCEPTION_IF_NULL(output_shape);
  shapes->push_back(input_shape);
  shapes->push_back(indices_shape);
  shapes->push_back(output_shape);
  const ParamInitInfoMessage &info = embedding_table_meta.info();
  ParamInitInfo param_init_info;
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    param_init_info.param_type_ = static_cast<ParamType>(info.param_type());
    if (param_init_info.param_type_ == kWeight) {
      param_init_info.global_seed_ = info.global_seed();
      param_init_info.op_seed_ = info.op_seed();
    } else if (param_init_info.param_type_ == kAccumulation) {
      param_init_info.init_val_ = info.init_val();
    }
  }
  ps_->InitEmbeddingTable(key, shapes, param_init_info);
}

void ParameterServer::ServerHandler::HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  input.ParseFromArray(data.get(), size);
  const Key &key = input.keys()[0];
  bool ready = ps_->ReadyForPush(key);
  MS_LOG(INFO) << "The ready is:" << ready;
  KVMessage res_data;
  res_data.add_keys(key);
  res_data.add_values(ready);
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s error, error code (" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  input.ParseFromArray(data.get(), size);
  const Key &key = input.keys()[0];
  bool ready = ps_->ReadyForPull(key);
  KVMessage res_data;
  res_data.add_keys(key);
  res_data.add_values(ready);
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s error, error code (" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res) {
  MS_EXCEPTION_IF_NULL(res);
  EmbeddingTableLookup input;
  input.ParseFromArray(data.get(), size);
  const Key &key = input.key();
  KVMessage res_data;
  std::vector<Key> keys = {input.keys().begin(), input.keys().end()};
  *res_data.mutable_keys() = {input.keys().begin(), input.keys().end()};
  ps_->DoEmbeddingLookup(key, keys, &res_data);
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s error, error code (" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  input.ParseFromArray(data.get(), size);
  const Key &key = input.keys()[0];
  const LookupIds &lookup_ids = {input.keys().begin() + 1, input.keys().end()};
  const Values &update_vals = {input.values().begin(), input.values().end()};
  ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
}

void ParameterServer::ServerHandler::HandleFinalize(DataPtr data, size_t size, VectorPtr res) {
  MS_EXCEPTION_IF_NULL(res);
  ps_->Finalize();
}
}  // namespace ps
}  // namespace mindspore