
parameter_server.cc

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ps/parameter_server.h"
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <functional>
#include <memory>
#include <numeric>
#include <random>
#include <thread>

namespace mindspore {
namespace ps {
static const uint32_t kMaxThreadNum = 16;
static const uint32_t kCPUCoreNum = std::thread::hardware_concurrency();

void ParameterServer::Run(const FuncGraphPtr &func_graph) {
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_LOG(INFO) << "PServer starts connecting to scheduler and workers...";
  server_node_ = std::make_shared<core::ServerNode>();
  MS_LOG(INFO) << "PServer connected successfully.";
  if (!PSContext::instance()->is_server()) {
    MS_LOG(INFO) << "This is not the Server node.";
    return;
  }
  Init(func_graph);
  server_node_->Start();
  PSContext::instance()->SetPSRankId(server_node_->rank_id());
  thread_->join();
  SyncEmbeddingTables();
  MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
  server_node_->Finish();
  if (!server_node_->Stop()) {
    MS_LOG(WARNING) << "Parameter server stop failed.";
  }
  MS_LOG(INFO) << "PServer finalized successfully.";
}

bool ParameterServer::Init(const FuncGraphPtr &func_graph) {
  pserver_num_ = std::strtol(mindspore::common::GetEnv(kEnvPServerNum).c_str(), nullptr, kBase);
  worker_num_ = std::strtol(mindspore::common::GetEnv(kEnvWorkerNum).c_str(), nullptr, kBase);
  func_graph_ = func_graph;
  handler_.reset(new ServerHandler(this));
  handler_->Init();
  InitOptimInfoBuilders();
  server_node_->set_handler(*handler_);
  server_node_->RegisterEventCallback(core::ClusterEvent::SCHEDULER_TIMEOUT, [this]() {
    MS_LOG(ERROR) << "Trigger timeout event: SCHEDULER_TIMEOUT. Begin to exit the system!";
    this->Finalize();
  });
  server_node_->RegisterEventCallback(core::ClusterEvent::NODE_TIMEOUT, [this]() {
    MS_LOG(ERROR) << "Trigger timeout event: NODE_TIMEOUT. Begin to exit the system!";
    this->Finalize();
  });
  thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this));
  GetEmbeddingTableParamPtr();
  return true;
}

void ParameterServer::InitOptimInfoBuilders() {
  std::shared_ptr<OptimizerInfoBuilder> momentum_info_builder = std::make_shared<MomentumOptimInfoBuilder>(worker_num_);
  std::shared_ptr<OptimizerInfoBuilder> sparse_adam_info_builder =
    std::make_shared<SparseAdamOptimInfoBuilder>(worker_num_);
  std::shared_ptr<OptimizerInfoBuilder> sparse_ftrl_info_builder =
    std::make_shared<SparseFtrlOptimInfoBuilder>(worker_num_);
  optim_info_builders_[kApplyMomentum] = momentum_info_builder;
  optim_info_builders_[kSparseAdam] = sparse_adam_info_builder;
  optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder;
}

void ParameterServer::InitWeightKeyToOptims(const Key &key, const int64_t &optim_id) {
  if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(optim_id) == "") {
    return;
  }
  weight_key_to_optims_[key] = Util::optimizer_name(optim_id);
  weight_key_to_optim_op_[key] = Util::optimizer_node_name(optim_id);
  MS_LOG(INFO) << "Initializing optimizer id for key:" << key << ", optimizer name:" << weight_key_to_optims_[key]
               << ", optimizer op name:" << weight_key_to_optim_op_[key];
}

void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) {
  InputsShapePtr inputs_shape = std::make_shared<InputsShape>();
  MS_EXCEPTION_IF_NULL(inputs_shape);
  InputsShapePtr original_inputs_shape = std::make_shared<InputsShape>();
  MS_EXCEPTION_IF_NULL(original_inputs_shape);
  size_t val_idx = 0;
  const Key &key = keys[0];
  MS_LOG(INFO) << "Initializing optimizer inputs shape for key:" << key;
  if (optim_inputs_shape_.count(key) == 0) {
    original_optim_inputs_shape_[key] = original_inputs_shape;
    optim_inputs_shape_[key] = inputs_shape;
  }
  for (size_t i = 0; i < keys.size(); i++) {
    auto shape = std::make_shared<std::vector<size_t>>();
    MS_EXCEPTION_IF_NULL(shape);
    auto original_shape = std::make_shared<std::vector<size_t>>();
    MS_EXCEPTION_IF_NULL(original_shape);
    inputs_shape->push_back(shape);
    original_inputs_shape->push_back(original_shape);
    for (int64_t j = 0; j < lengths[i]; j++) {
      shape->push_back(values[val_idx]);
      original_shape->push_back(values[val_idx++]);
    }
  }
  if (weight_key_to_optims_.count(key) > 0) {
    const std::string &optim_name = weight_key_to_optims_[key];
    const std::string &optim_op_name = weight_key_to_optim_op_[key];
    if (optimizers_.count(key) == 0 && optim_inputs_shape_.count(key) > 0) {
      const CNodePtr cnode = GetCNode(optim_op_name);
      MS_EXCEPTION_IF_NULL(cnode);
      if (optim_name == kSparseAdam) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kSparseLazyAdam) {
        std::shared_ptr<PServerKernel> optimizer = std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(
          server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kApplyMomentum) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::ApplyMomentumPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      } else if (optim_name == kSparseFtrl) {
        std::shared_ptr<PServerKernel> optimizer =
          std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
        optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
        optimizers_[key] = optimizer;
      }
    }
  }
}

void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) {
  MS_EXCEPTION_IF_NULL(weight);
  if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
    MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << server_node_->rank_id();
    weights_[key] = weight;
    tokens_[key] = 0;
    is_embedding_[key] = false;
  }
}

void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) {
  MS_EXCEPTION_IF_NULL(grad);
  if (grads_.count(key) == 0) {
    grads_[key] = grad;
    grads_accum_counter_[key] = 0;
  }
}

namespace {
// Initialize accumulation by multithreading parallelism.
void InitAccumParallel(float init_value, size_t total_len, float *embedding_data) {
  MS_EXCEPTION_IF_NULL(embedding_data);
  auto init_task = [](float value, size_t task_len, float *data) {
    for (size_t i = 0; i < task_len; i++) {
      data[i] = value;
    }
  };
  size_t thread_num = std::max(kMaxThreadNum, kCPUCoreNum);
  if (total_len <= thread_num) {
    thread_num = 1;
  }
  std::vector<std::thread> threads(thread_num);
  size_t task_offset = 0;
  for (size_t i = 0; i < thread_num; ++i) {
    // The value of thread_num is >= 1.
    size_t task_len = total_len / thread_num + (i < (total_len % thread_num) ? 1 : 0);
    threads[i] = std::thread(init_task, init_value, task_len, embedding_data + task_offset);
    task_offset += task_len;
  }
  for (size_t i = 0; i < thread_num; i++) {
    threads[i].join();
  }
}
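
// A worked instance of the split in InitAccumParallel: with total_len = 10 and thread_num = 4,
// total_len / thread_num = 2 and total_len % thread_num = 2, so threads 0 and 1 initialize 3 floats each
// while threads 2 and 3 initialize 2 each. Because task_offset advances by every task_len, the slices tile
// [0, total_len) exactly, with no gap or overlap.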

void CopyTensorData(void *dest_ptr, size_t tensor_size, const void *src_ptr) {
  MS_EXCEPTION_IF_NULL(dest_ptr);
  MS_EXCEPTION_IF_NULL(src_ptr);
  char *dest = reinterpret_cast<char *>(dest_ptr);
  const char *src = reinterpret_cast<const char *>(src_ptr);
  // The secure memcpy function 'memcpy_s' limits its second parameter 'destMax' to at most
  // SECUREC_MEM_MAX_LEN. If the tensor size (buffer length) is greater than SECUREC_MEM_MAX_LEN, the tensor
  // has to be copied in segments.
  for (size_t offset = 0; offset < tensor_size; offset += SECUREC_MEM_MAX_LEN) {
    size_t copy_len = std::min(tensor_size - offset, SECUREC_MEM_MAX_LEN);
    size_t dest_len = copy_len;
    int ret = memcpy_s(dest + offset, dest_len, src + offset, copy_len);
    if (ret != 0) {
      MS_LOG(EXCEPTION) << "Failed to memcpy tensor, error number(" << ret << ")";
    }
  }
}
}  // namespace
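
// For CopyTensorData above: if tensor_size were 2.5 * SECUREC_MEM_MAX_LEN, the loop would issue three
// memcpy_s calls with lengths SECUREC_MEM_MAX_LEN, SECUREC_MEM_MAX_LEN and 0.5 * SECUREC_MEM_MAX_LEN, so
// the segments cover the whole buffer while each call stays within the secure-function limit.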

void ParameterServer::InitEmbeddingTable(
  const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes,
  const ParamInitInfo &param_init_info) {
  MS_EXCEPTION_IF_NULL(shapes);
  if (weights_.count(key) == 0) {
    std::shared_ptr<PServerKernel> lookup =
      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
    lookup->InitKernel(shapes);
    embedding_lookup_ops_[key] = lookup;
    // Init embedding weight
    const std::vector<size_t> &input_shapes = lookup->input_sizes();
    size_t total_dims =
      std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies<size_t>());
    WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
    MS_EXCEPTION_IF_NULL(embedding);
    float *embedding_data = embedding->data();
    std::default_random_engine engine;
    std::normal_distribution<float> random(0, kStdDev);
    if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
      CacheEmbeddingTableParamPtr();
      if (param_init_info.param_type_ == kWeight) {
        const std::string &param_name = param_init_info.param_name_;
        auto iter = embedding_parameter_tables_.find(param_name);
        if (iter == embedding_parameter_tables_.end()) {
          MS_LOG(EXCEPTION) << "Cannot find parameter info for: " << param_name;
        }
        // Cache embedding table parameter by weight key to parameter node pointer.
        (void)embedding_tables_.emplace(key, iter->second);
        InitRandomNormal(0, kStdDev, input_shapes, param_init_info.global_seed_, param_init_info.op_seed_,
                         embedding_data);
      } else if (param_init_info.param_type_ == kAccumulation) {
        InitAccumParallel(param_init_info.init_val_, total_dims, embedding_data);
      }
    } else {
      for (size_t i = 0; i < total_dims; i++) {
        embedding_data[i] = random(engine);
      }
    }
    weights_[key] = embedding;
    MS_LOG(DEBUG) << "The key:" << key << " the embedding:" << *embedding;
    tokens_[key] = 0;
    is_embedding_[key] = true;
    grads_accum_counter_[key] = 0;
  }
}
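
// Note on the two initialization paths above: with the embedding cache enabled, weight-type tables are
// filled by InitRandomNormal using the recorded global and op seeds (so every server shard is
// reproducible) and accumulation-type tables are filled with a constant via InitAccumParallel; without the
// cache, the table is filled from a locally constructed std::normal_distribution with standard deviation
// kStdDev.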

bool ParameterServer::HasWeight(const Key &key) { return (weights_.count(key) > 0 && !is_embedding_.count(key)); }

void ParameterServer::Finalize() {
  running_ = false;
  apply_grads_cv_.notify_one();
}

void ParameterServer::UpdateWeights() {
  while (true) {
    MS_LOG(INFO) << "The running flag is:" << running_ << ", ready for updating weights:" << this->ReadyForUpdateWeights();
    std::unique_lock<std::mutex> lock(mutex_);
    apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights() || !running_; });
    if (!running_) {
      break;
    }
    for (auto iter = weights_.begin(); iter != weights_.end(); iter++) {
      Key key = iter->first;
      WeightPtr weight_ptr = iter->second;
      std::shared_ptr<PServerKernel> optimizer = nullptr;
      if (weight_key_to_optims_.count(key) > 0) {
        optimizer = optimizers_[key];
      }
      MS_EXCEPTION_IF_NULL(optimizer);
      std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
      if (optim_info != nullptr) {
        const std::vector<kernel::AddressPtr> &inputs = optim_info->inputs();
        const std::vector<kernel::AddressPtr> &workspaces = optim_info->workspaces();
        const std::vector<kernel::AddressPtr> &outputs = optim_info->outputs();
        std::vector<std::vector<size_t>> shapes = {};
        std::vector<size_t> indices_shape = {};
        indices_shape.emplace_back(optim_info->indice_size());
        shapes.push_back(indices_shape);
        if (original_optim_inputs_shape_.count(key) != 0) {
          std::transform((*(original_optim_inputs_shape_[key])).begin(), (*(original_optim_inputs_shape_[key])).end(),
                         std::back_inserter(shapes),
                         [](const std::shared_ptr<std::vector<size_t>> &input_shapes) -> std::vector<size_t> {
                           return *input_shapes;
                         });
        }
        optimizer->ReInit(shapes);
        optim_info->ComputeMean(shapes, worker_num_, pserver_num_, server_node_->rank_id());
        optimizer->Execute(inputs, workspaces, outputs);
        optim_info->Reset();
      }
      if (!is_embedding_[key]) {
        tokens_[key] = worker_num_;
      }
    }
    ResetGradAccumCount();
  }
}

void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) {
  std::unique_lock<std::mutex> lock(mutex_);
  const Key &key = keys[0];
  bool no_sparse_grad = values.size() == 1 && values[0] == kGradValue;
  if (!no_sparse_grad) {
    std::shared_ptr<OptimizerInfo> optim_info = optim_infos_[key];
    // Create or update the optimizer info
    if (optim_info == nullptr) {
      const std::shared_ptr<OptimizerInfoBuilder> &builder = optim_info_builders_[weight_key_to_optims_[key]];
      std::shared_ptr<kernel::ps::PServerKernel> pserver_kernel = optimizers_[key];
      if (pserver_kernel == nullptr) {
        MS_LOG(EXCEPTION) << "No optimizer found for key " << key << ", optim name " << weight_key_to_optims_[key];
      }
      MS_EXCEPTION_IF_NULL(pserver_kernel);
      OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths,
                                            optim_inputs_shape_[key], worker_num_, is_embedding_[key]);
      optim_info.reset(optim);
      optim_infos_[key] = optim_info;
    } else {
      optim_info->Update(values, lengths);
      optim_info->Accumulate(values, lengths);
    }
  }
  grads_accum_counter_[key] += 1;
  if (grads_accum_counter_[key] == worker_num_) {
    grad_accum_count_++;
  }
  if (ReadyForUpdateWeights()) {
    apply_grads_cv_.notify_one();
  }
}
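
// Synchronization between AccumGrad and UpdateWeights: each worker push increments
// grads_accum_counter_[key]; once a key has been pushed by all worker_num_ workers it contributes to
// grad_accum_count_. When every key is fully accumulated (ReadyForUpdateWeights), the condition variable
// wakes the UpdateWeights thread, which runs the optimizers, refills tokens_ for non-embedding keys and
// resets both counters.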

WeightPtr ParameterServer::weight(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (weights_.count(key) == 0) {
    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
  }
  WeightPtr weight_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(weight_ptr);
  // Return a copy of the stored weight so that callers cannot mutate it in place.
  WeightPtr copy_weight_ptr = std::make_shared<std::vector<float>>(*weight_ptr);
  MS_EXCEPTION_IF_NULL(copy_weight_ptr);
  tokens_[key] -= 1;
  return copy_weight_ptr;
}

void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res) {
  std::unique_lock<std::mutex> lock(mutex_);
  MS_EXCEPTION_IF_NULL(res);
  if (weights_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding table key " << key;
    return;
  }
  if (embedding_lookup_ops_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
    return;
  }
  WeightPtr table_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(table_ptr);
  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
  MS_EXCEPTION_IF_NULL(table_lookup_op);
  // Update shapes of lookup operator
  std::vector<std::vector<size_t>> shapes = {};
  std::vector<size_t> indices_shape = {};
  indices_shape.emplace_back(lookup_ids.size());
  shapes.push_back(indices_shape);
  table_lookup_op->ReInit(shapes);
  const std::vector<size_t> output_shapes = table_lookup_op->output_sizes();
  std::vector<kernel::AddressPtr> inputs;
  AddressPtr embedding_table = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(embedding_table);
  AddressPtr indices = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(indices);
  inputs.push_back(embedding_table);
  inputs.push_back(indices);
  embedding_table->addr = table_ptr->data();
  embedding_table->size = table_ptr->size() * sizeof(float);
  std::unique_ptr<int[]> tmp_ids = std::make_unique<int[]>(lookup_ids.size());
  MS_EXCEPTION_IF_NULL(tmp_ids);
  for (size_t i = 0; i < lookup_ids.size(); i++) {
    tmp_ids[i] = static_cast<int>(lookup_ids[i]);
  }
  indices->addr = tmp_ids.get();
  indices->size = lookup_ids.size() * sizeof(int);
  std::vector<kernel::AddressPtr> workspaces;
  std::vector<kernel::AddressPtr> outputs;
  AddressPtr output = std::make_shared<kernel::Address>();
  MS_EXCEPTION_IF_NULL(output);
  std::shared_ptr<Values> addr = std::make_shared<Values>(output_shapes[0] / sizeof(float), 0);
  MS_EXCEPTION_IF_NULL(addr);
  output->addr = addr->data();
  output->size = output_shapes[0];
  outputs.push_back(output);
  table_lookup_op->Execute(inputs, workspaces, outputs);
  *res->mutable_values() = {addr->begin(), addr->end()};
  res->add_len(res->values_size());
}
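
// The lookup reply produced above carries the gathered rows flattened into values plus a single len entry
// equal to values_size(), since DoEmbeddingLookup serves exactly one embedding table key per request.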

void ParameterServer::UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals) {
  if (weights_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding table key " << key;
    return;
  }
  if (embedding_lookup_ops_.count(key) == 0) {
    MS_LOG(ERROR) << "Invalid embedding lookup op key " << key;
    return;
  }
  WeightPtr table_ptr = weights_[key];
  MS_EXCEPTION_IF_NULL(table_ptr);
  std::shared_ptr<PServerKernel> table_lookup_op = embedding_lookup_ops_[key];
  MS_EXCEPTION_IF_NULL(table_lookup_op);
  table_lookup_op->UpdateEmbeddings(table_ptr->data(), lookup_ids.data(), vals.data(), lookup_ids.size());
}

inline bool ParameterServer::ReadyForUpdateWeights() const {
  return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size();
}

inline bool ParameterServer::ReadyForPush(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (weights_.empty()) {
    MS_LOG(EXCEPTION) << "The weights on the server are empty. Many reasons could cause this: 1. The worker didn't "
                         "send the kInitWeightsCmd command. 2. The server failed to initialize weights.";
  }
  return grad_accum_count_ < weights_.size() && tokens_[key] == 0;
}

inline bool ParameterServer::ReadyForPull(const Key &key) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (tokens_.count(key) == 0 || weights_[key] == nullptr) {
    MS_LOG(EXCEPTION) << "Invalid weight key " << key;
  }
  MS_LOG(INFO) << "ReadyForPull: " << (tokens_[key] > 0);
  return tokens_[key] > 0;
}
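
// Token accounting: UpdateWeights sets tokens_[key] = worker_num_ after applying an update, each pull in
// weight() decrements the count, ReadyForPull holds while tokens remain, and ReadyForPush holds only once
// the tokens for the key are exhausted. This alternates push and pull phases per key, so workers pull a
// freshly updated weight.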

inline void ParameterServer::ResetGradAccumCount() {
  grad_accum_count_ = 0;
  for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) {
    grads_accum_counter_[iter->first] = 0;
  }
}

const CNodePtr ParameterServer::GetCNode(const std::string &name) const {
  std::list<CNodePtr> cnodes = func_graph_->GetOrderedCnodes();
  for (CNodePtr cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string fullname = cnode->fullname_with_scope();
    if (fullname.find(name) != std::string::npos && fullname.find("Push") != std::string::npos) {
      return cnode;
    }
  }
  return nullptr;
}

inline std::mutex &ParameterServer::mutex() { return mutex_; }

void ParameterServer::GetEmbeddingTableParamPtr() {
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(func_graph_);
  auto cnodes = func_graph_->GetOrderedCnodes();
  Key count = 0;
  for (auto cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
    if (cnode_name == kEmbeddingLookupOpName || cnode_name == kGatherV2OpName || cnode_name == kSparseGatherV2OpName) {
      auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
      if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) {
        auto embedding_cnode = embedding_table->cast<CNodePtr>();
        embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0);
      }
      MS_EXCEPTION_IF_NULL(embedding_table);
      if (embedding_table->isa<Parameter>()) {
        MS_LOG(INFO) << "Embedding table name is " << embedding_table->fullname_with_scope() << ", key is " << count;
        (void)embedding_tables_.emplace(count, embedding_table->cast<ParameterPtr>());
        count++;
      }
    }
  }
}

void ParameterServer::CacheEmbeddingTableParamPtr() {
  if (embedding_param_ptr_cached_) {
    return;
  }
  MS_EXCEPTION_IF_NULL(func_graph_);
  auto cnodes = func_graph_->GetOrderedCnodes();
  for (auto cnode : cnodes) {
    MS_EXCEPTION_IF_NULL(cnode);
    std::string cnode_name = AnfAlgo::GetCNodeName(cnode);
    if (cnode_name != kGatherV2OpName && cnode_name != kSparseGatherV2OpName) {
      continue;
    }
    auto embedding_table = AnfAlgo::GetInputNode(cnode, 0);
    if (IsPrimitiveCNode(embedding_table, prim::kPrimLoad)) {
      auto embedding_cnode = embedding_table->cast<CNodePtr>();
      embedding_table = AnfAlgo::GetInputNode(embedding_cnode, 0);
    }
    MS_EXCEPTION_IF_NULL(embedding_table);
    if (embedding_table->isa<Parameter>()) {
      (void)embedding_parameter_tables_.emplace(embedding_table->fullname_with_scope(),
                                                embedding_table->cast<ParameterPtr>());
    }
  }
  embedding_param_ptr_cached_ = true;
}

void ParameterServer::SyncEmbeddingTables() {
  for (auto embedding_table : embedding_tables_) {
    Key key = embedding_table.first;
    if (embedding_lookup_ops_.count(key) == 0) {
      MS_LOG(WARNING) << "Can't find the embedding lookup PS kernel for key " << key;
      continue;
    }
    auto lookup = embedding_lookup_ops_[key];
    const std::vector<size_t> &input_shapes = lookup->input_sizes();
    std::vector<int64_t> new_tensor_shape(input_shapes.begin(), input_shapes.end());
    tensor::TensorPtr new_tensor = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, new_tensor_shape);
    MS_EXCEPTION_IF_NULL(new_tensor);
    float *new_tensor_data_ptr = reinterpret_cast<float *>(new_tensor->data_c());
    size_t new_tensor_size = static_cast<size_t>(new_tensor->data().nbytes());
    size_t embedding_table_size = weights_[key]->size() * sizeof(float);
    if (new_tensor_size != embedding_table_size) {
      MS_LOG(EXCEPTION) << "The shape of the embedding table doesn't match. New tensor size:" << new_tensor_size
                        << ", embedding table size:" << embedding_table_size;
    }
    MS_EXCEPTION_IF_NULL(new_tensor_data_ptr);
    MS_EXCEPTION_IF_NULL(weights_[key]->data());
    CopyTensorData(new_tensor_data_ptr, new_tensor_size, weights_[key]->data());
    auto parameter_tensor_ptr = embedding_table.second->default_param();
    MS_EXCEPTION_IF_NULL(parameter_tensor_ptr);
    parameter_tensor_ptr->cast<tensor::TensorPtr>()->AssignValue(*new_tensor);
  }
}
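
// SyncEmbeddingTables is called from Run() after the update thread exits: it copies each server-side
// embedding shard back into the corresponding Parameter's default tensor so the trained values survive
// finalization.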

void ParameterServer::ServerHandler::Init() {
  handlers_[kInitWeightsCmd] = &ServerHandler::HandleInitWeights;
  handlers_[kInitWeightToOptimIdCmd] = &ServerHandler::HandleInitWeightToOptimId;
  handlers_[kInitOptimInputsShapeCmd] = &ServerHandler::HandleInitInputsShape;
  handlers_[kInitEmbeddingsCmd] = &ServerHandler::HandleInitEmbeddings;
  handlers_[kCheckReadyForPushCmd] = &ServerHandler::HandleCheckReadyForPush;
  handlers_[kCheckReadyForPullCmd] = &ServerHandler::HandleCheckReadyForPull;
  handlers_[kEmbeddingLookupCmd] = &ServerHandler::HandleEmbeddingLookup;
  handlers_[kUpdateEmbeddingsCmd] = &ServerHandler::HandleUpdateEmbeddings;
  handlers_[kFinalizeCmd] = &ServerHandler::HandleFinalize;
  handlers_[kPushCmd] = &ServerHandler::HandlePushReq;
  handlers_[kPullCmd] = &ServerHandler::HandlePullReq;
  commands_[kInitWeightsCmd] = "kInitWeightsCmd";
  commands_[kInitWeightToOptimIdCmd] = "kInitWeightToOptimIdCmd";
  commands_[kInitOptimInputsShapeCmd] = "kInitOptimInputsShapeCmd";
  commands_[kInitEmbeddingsCmd] = "kInitEmbeddingsCmd";
  commands_[kCheckReadyForPushCmd] = "kCheckReadyForPushCmd";
  commands_[kCheckReadyForPullCmd] = "kCheckReadyForPullCmd";
  commands_[kEmbeddingLookupCmd] = "kEmbeddingLookupCmd";
  commands_[kUpdateEmbeddingsCmd] = "kUpdateEmbeddingsCmd";
  commands_[kFinalizeCmd] = "kFinalizeCmd";
  commands_[kPushCmd] = "kPushCmd";
  commands_[kPullCmd] = "kPullCmd";
}
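
// handlers_ maps each command id to a pointer-to-member handler and commands_ mirrors the ids with
// human-readable names for logging; operator() below looks up meta->user_cmd() in both tables to dispatch
// an incoming message.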

void ParameterServer::ServerHandler::operator()(const std::shared_ptr<core::TcpConnection> &conn,
                                                const std::shared_ptr<core::MessageMeta> &meta, const DataPtr &data,
                                                size_t size) {
  auto output = std::make_shared<std::vector<unsigned char>>();
  if (commands_.count(meta->user_cmd()) == 0) {
    MS_LOG(EXCEPTION) << "The command:" << meta->user_cmd() << " is not supported!";
  }
  MS_LOG(INFO) << "The command is:" << commands_[meta->user_cmd()];
  auto &handler_ptr = handlers_[meta->user_cmd()];
  (this->*handler_ptr)(data, size, output);
  MS_LOG(DEBUG) << "The output size is:" << output->size();
  if (output->size() > 0) {
    ps_->server_node_->Response(conn, meta, output->data(), output->size());
  } else {
    // If the output is empty, respond with an empty string. Because the Response function is synchronous,
    // the res variable can safely be recycled after the call returns.
    std::string res;
    ps_->server_node_->Response(conn, meta, res.data(), res.length());
  }
  MS_LOG(DEBUG) << "The request id is:" << meta->request_id() << " the current time is:"
                << std::chrono::time_point_cast<std::chrono::microseconds>(std::chrono::high_resolution_clock::now())
                     .time_since_epoch()
                     .count();
}

void ParameterServer::ServerHandler::HandlePushReq(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  Keys keys = {input.keys().begin(), input.keys().end()};
  Values values = {input.values().begin(), input.values().end()};
  Lengths lens = {input.len().begin(), input.len().end()};
  MS_LOG(DEBUG) << "The keys:" << keys << " the values:" << values << " the len:" << lens;
  ps_->AccumGrad(keys, values, lens);
}

void ParameterServer::ServerHandler::HandlePullReq(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  KVMessage res_data;
  *res_data.mutable_keys() = input.keys();
  Key key = input.keys()[0];
  auto weight = ps_->weight(key);
  *res_data.mutable_values() = {weight->begin(), weight->end()};
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s failed, error number(" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleInitWeights(const DataPtr &data, size_t size, const VectorPtr &res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  int key_num = input.keys_size();
  const float *data_ptr = input.values().data();
  size_t pos = 0;
  for (int i = 0; i < key_num; i++) {
    Key key = input.keys()[i];
    size_t data_len = input.len_size() != key_num ? input.values_size() / key_num : input.len()[i];
    if (!ps_->HasWeight(key)) {
      WeightPtr weight_ptr = std::make_shared<std::vector<float>>(data_ptr + pos, data_ptr + (pos + data_len));
      MS_EXCEPTION_IF_NULL(weight_ptr);
      ps_->InitWeight(key, weight_ptr);
      GradPtr grad_ptr = std::make_shared<std::vector<float>>(data_len, 0);
      MS_EXCEPTION_IF_NULL(grad_ptr);
      ps_->InitGrad(key, grad_ptr);
    }
    pos += data_len;
  }
}
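
// In HandleInitWeights above, when the request carries no per-key length vector (len_size() != key_num),
// the values payload is split evenly across the keys; otherwise each key consumes exactly input.len()[i]
// floats starting at the running offset pos.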

void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const DataPtr &data, size_t size, const VectorPtr &res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  int key_num = input.keys_size();
  for (int i = 0; i < key_num; i++) {
    Key key = input.keys()[i];
    float val = input.values()[i];
    if (init_weight_to_optim_[key]) {
      continue;
    } else {
      init_weight_to_optim_[key] = true;
    }
    ps_->InitWeightKeyToOptims(key, static_cast<int64_t>(val));
  }
}

void ParameterServer::ServerHandler::HandleInitInputsShape(const DataPtr &data, size_t size, const VectorPtr &res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.keys()[0];
  if (init_optim_info_[key]) {
    return;
  } else {
    init_optim_info_[key] = true;
  }
  Keys keys = {input.keys().begin(), input.keys().end()};
  Values values = {input.values().begin(), input.values().end()};
  Lengths lens = {input.len().begin(), input.len().end()};
  ps_->InitOptimInputsShape(keys, values, lens);
}

void ParameterServer::ServerHandler::HandleInitEmbeddings(const DataPtr &data, size_t size, const VectorPtr &) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  EmbeddingTableMeta embedding_table_meta;
  CHECK_RETURN_TYPE(embedding_table_meta.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = embedding_table_meta.key();
  MS_LOG(INFO) << "Initializing embedding table for key:" << key;
  std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> shapes =
    std::make_shared<std::vector<std::shared_ptr<std::vector<size_t>>>>();
  MS_EXCEPTION_IF_NULL(shapes);
  std::shared_ptr<std::vector<size_t>> input_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.input_shape().begin(), embedding_table_meta.input_shape().end());
  MS_EXCEPTION_IF_NULL(input_shape);
  std::shared_ptr<std::vector<size_t>> indices_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.indices_shape().begin(), embedding_table_meta.indices_shape().end());
  MS_EXCEPTION_IF_NULL(indices_shape);
  std::shared_ptr<std::vector<size_t>> output_shape = std::make_shared<std::vector<size_t>>(
    embedding_table_meta.output_shape().begin(), embedding_table_meta.output_shape().end());
  MS_EXCEPTION_IF_NULL(output_shape);
  shapes->push_back(input_shape);
  shapes->push_back(indices_shape);
  shapes->push_back(output_shape);
  const ParamInitInfoMessage &info = embedding_table_meta.info();
  ParamInitInfo param_init_info;
  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
    param_init_info.param_name_ = info.param_name();
    param_init_info.param_type_ = static_cast<ParamType>(info.param_type());
    if (param_init_info.param_type_ == kWeight) {
      param_init_info.global_seed_ = info.global_seed();
      param_init_info.op_seed_ = info.op_seed();
    } else if (param_init_info.param_type_ == kAccumulation) {
      param_init_info.init_val_ = info.init_val();
    }
  }
  ps_->InitEmbeddingTable(key, shapes, param_init_info);
}

void ParameterServer::ServerHandler::HandleCheckReadyForPush(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.keys()[0];
  bool ready = ps_->ReadyForPush(key);
  MS_LOG(INFO) << "The ready is:" << ready;
  KVMessage res_data;
  res_data.add_keys(key);
  res_data.add_values(ready);
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s failed, error number(" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleCheckReadyForPull(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.keys()[0];
  bool ready = ps_->ReadyForPull(key);
  KVMessage res_data;
  res_data.add_keys(key);
  res_data.add_values(ready);
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s failed, error number(" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleEmbeddingLookup(const DataPtr &data, size_t size, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  EmbeddingTableLookup input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.key();
  KVMessage res_data;
  std::vector<Key> keys = {input.keys().begin(), input.keys().end()};
  *res_data.mutable_keys() = {input.keys().begin(), input.keys().end()};
  ps_->DoEmbeddingLookup(key, keys, &res_data);
  res->resize(res_data.ByteSizeLong());
  size_t dest_size = res_data.ByteSizeLong();
  size_t src_size = res_data.ByteSizeLong();
  int ret = memcpy_s(res->data(), dest_size, res_data.SerializeAsString().data(), src_size);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "The memcpy_s failed, error number(" << ret << ")";
  }
}

void ParameterServer::ServerHandler::HandleUpdateEmbeddings(const DataPtr &data, size_t size, const VectorPtr &res) {
  std::unique_lock<std::mutex> lock(ps_->mutex());
  MS_EXCEPTION_IF_NULL(res);
  KVMessage input;
  CHECK_RETURN_TYPE(input.ParseFromArray(data.get(), SizeToInt(size)));
  const Key &key = input.keys()[0];
  const LookupIds &lookup_ids = {input.keys().begin() + 1, input.keys().end()};
  const Values &update_vals = {input.values().begin(), input.values().end()};
  ps_->UpdateEmbeddings(key, lookup_ids, update_vals);
}

void ParameterServer::ServerHandler::HandleFinalize(const DataPtr &, size_t, const VectorPtr &res) {
  MS_EXCEPTION_IF_NULL(res);
  ps_->Finalize();
}
}  // namespace ps
}  // namespace mindspore