From 40a65e3878f45570793520ed93c8933a72a4587e Mon Sep 17 00:00:00 2001 From: cristoval Date: Wed, 19 Aug 2020 00:50:43 +0800 Subject: [PATCH] bugfix for server crash --- .../cpu/ps/apply_momentum_ps_kernel.h | 3 +- .../cpu/ps/embedding_look_up_ps_kernel.h | 3 +- .../kernel_compiler/cpu/ps/pserver_kernel.h | 4 +- .../cpu/ps/sparse_apply_adam_ps_kernel.cc | 10 +-- .../cpu/ps/sparse_apply_adam_ps_kernel.h | 3 +- .../cpu/ps/sparse_apply_ftrl_ps_kernel.cc | 8 +-- .../cpu/ps/sparse_apply_ftrl_ps_kernel.h | 3 +- .../ps/sparse_apply_lazy_adam_ps_kernel.cc | 9 ++- .../cpu/ps/sparse_apply_lazy_adam_ps_kernel.h | 3 +- .../frontend/parallel/ps/optimizer_info.cc | 63 +++++++++++++--- .../frontend/parallel/ps/parameter_server.h | 71 +++++++++++-------- 11 files changed, 120 insertions(+), 60 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h index 18b040cc86..8918299540 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h @@ -26,7 +26,8 @@ namespace kernel { namespace ps { class ApplyMomentumPSKernel : public ApplyMomentumCPUKernel, public PServerKernel { public: - ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num) + : PServerKernel(rank_id, pserver_num, worker_num) {} ~ApplyMomentumPSKernel() override = default; bool Execute(const std::vector &inputs, const std::vector &workspace, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h index 987de740d8..815b2b6f77 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h @@ -26,7 +26,8 @@ namespace kernel { namespace ps { class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerKernel { public: - EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num) + : PServerKernel(rank_id, pserver_num, worker_num) {} ~EmbeddingLookUpPSKernel() override = default; void InitKernel(const std::shared_ptr>>> &) override; diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h index 158b890929..af0a1fe342 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/pserver_kernel.h @@ -27,7 +27,8 @@ namespace ps { using mindspore::parallel::ps::Util; class PServerKernel { public: - PServerKernel(size_t rank_id, size_t pserver_num) : rank_id_(rank_id), pserver_num_(pserver_num) {} + PServerKernel(size_t rank_id, size_t pserver_num, size_t worker_num) + : rank_id_(rank_id), pserver_num_(pserver_num), worker_num_(worker_num) {} ~PServerKernel() = default; PServerKernel(const PServerKernel &) = delete; PServerKernel &operator=(const PServerKernel &) = delete; @@ -50,6 +51,7 @@ class PServerKernel { size_t rank_id_; size_t pserver_num_; + size_t worker_num_; }; } // namespace ps } // namespace kernel diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc
b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc index 222a980fc6..4ae5745133 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.cc @@ -58,11 +58,11 @@ void SparseApplyAdamPSKernel::InitKernel( if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) { use_nesterov_ = AnfAlgo::GetNodeAttr(cnode, "use_nesterov"); } - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_); + workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_); } void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr>>> &shapes) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h index bd3e021a69..a28e62abd6 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h @@ -27,7 +27,8 @@ namespace ps { using mindspore::kernel::SparseApplyAdamCPUKernel; class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerKernel { public: - SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num) + : PServerKernel(rank_id, pserver_num, worker_num) {} ~SparseApplyAdamPSKernel() override = default; void InitKernel(const CNodePtr &cnode, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc index afd676382f..b5f5f0259c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.cc @@ -62,10 +62,10 @@ void SparseApplyFtrlPSKernel::InitKernel( if (lr_power_ > 0) { MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar"; } - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_); } void SparseApplyFtrlPSKernel::ReInit(const 
std::shared_ptr>>> &shapes) { diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h index 3a5dfc738e..6d37dd4495 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h @@ -27,7 +27,8 @@ namespace ps { using mindspore::kernel::SparseApplyFtrlCPUKernel; class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerKernel { public: - SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num) + : PServerKernel(rank_id, pserver_num, worker_num) {} ~SparseApplyFtrlPSKernel() override = default; void InitKernel(const CNodePtr &cnode, diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc index 03949b3685..0c5c5947a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.cc @@ -58,11 +58,10 @@ void SparseApplyLazyAdamPSKernel::InitKernel( if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) { use_nesterov_ = AnfAlgo::GetNodeAttr(cnode, "use_nesterov"); } - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float)); - workspace_size_list_.emplace_back(indices_size_ * sizeof(int)); - workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float)); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_); + workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_); } void SparseApplyLazyAdamPSKernel::ReInit( diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h index 595f2ab6a3..070f42b96c 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h @@ -27,7 +27,8 @@ namespace ps { using mindspore::kernel::SparseApplyLazyAdamCPUKernel; class SparseApplyLazyAdamPSKernel : public SparseApplyLazyAdamCPUKernel, public PServerKernel { public: - SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {} + SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num) + : PServerKernel(rank_id, pserver_num, worker_num) {} ~SparseApplyLazyAdamPSKernel() override = default; void InitKernel(const CNodePtr &cnode, diff --git a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc index 5801b241e4..dd8fedee6d 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc +++ b/mindspore/ccsrc/frontend/parallel/ps/optimizer_info.cc @@ -85,7 +85,7 @@ void 
SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } - grads_offset_ += incr_grad_size; + grads_offset_ += lengths[grad_index]; gradient()->size += incr_grad_size; // Append indice data to the end @@ -103,7 +103,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) { if (ret2 != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")"; } - indices_offset_ += incr_indice_size; + indices_offset_ += lengths[indices_index]; indices()->size += incr_indice_size; } @@ -157,15 +157,58 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address inputs_.push_back(epsilon); inputs_.push_back(grad); inputs_.push_back(indices); - grads_offset_ = 0; - indices_offset_ = 0; + grads_offset_ = grad->size / sizeof(float); + indices_offset_ = indices->size / sizeof(int); } void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) { - void *data_ptr = values.data(); - AddressPtr beta1_power = inputs_[3]; - size_t size = values.size() * sizeof(float); - auto ret = memcpy_s(beta1_power->addr, size, data_ptr, size); + float *data_ptr = values.data(); + int offset = 0; + + AddressPtr &beta1_power = inputs_[3]; + int size = lens[0]; + int bytes = sizeof(float); + auto ret = memcpy_s(beta1_power->addr, size * bytes, data_ptr + offset, size * bytes); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + offset += size; + AddressPtr &beta2_power = inputs_[4]; + size = lens[1]; + ret = memcpy_s(beta2_power->addr, size * bytes, data_ptr + offset, size * bytes); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + offset += size; + AddressPtr &lr = inputs_[5]; + size = lens[2]; + ret = memcpy_s(lr->addr, size * bytes, data_ptr + offset, size * bytes); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + offset += size; + AddressPtr &beta1 = inputs_[6]; + size = lens[3]; + ret = memcpy_s(beta1->addr, size * bytes, data_ptr + offset, size * bytes); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + offset += size; + AddressPtr &beta2 = inputs_[7]; + size = lens[4]; + ret = memcpy_s(beta2->addr, size * bytes, data_ptr + offset, size * bytes); + if (ret != 0) { + MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; + } + + offset += size; + AddressPtr &epsilon = inputs_[8]; + size = lens[5]; + ret = memcpy_s(epsilon->addr, size * bytes, data_ptr + offset, size * bytes); if (ret != 0) { MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; } @@ -188,8 +231,8 @@ SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const Address inputs_.push_back(linear); inputs_.push_back(grad); inputs_.push_back(indices); - grads_offset_ = 0; - indices_offset_ = 0; + grads_offset_ = grad->size / sizeof(float); + indices_offset_ = indices->size / sizeof(int); } const AddressPtr &SparseFtrlOptimInfo::gradient() { return inputs_[3]; } diff --git a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h index 60e88a9c10..984ff92a0c 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h +++ b/mindspore/ccsrc/frontend/parallel/ps/parameter_server.h @@ -114,6 +114,7 @@ class ParameterServer { void InitGrad(const Key &key, const GradPtr &grad); void InitEmbeddingTable(const Key &key, const std::shared_ptr>>> 
&shapes); + bool HasWeight(const Key &key); void Finalize(); void UpdateWeights(); void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); @@ -211,12 +212,14 @@ void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVMeta &re Key key = req_data.keys[i]; size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i]; - WeightPtr weight_ptr = std::make_shared<::ps::SArray>(); - weight_ptr->CopyFrom(data_ptr + pos, data_len); - ps_->InitWeight(key, weight_ptr); + if (!ps_->HasWeight(key)) { + WeightPtr weight_ptr = std::make_shared<::ps::SArray>(); + weight_ptr->CopyFrom(data_ptr + pos, data_len); + ps_->InitWeight(key, weight_ptr); - GradPtr grad_ptr = std::make_shared<::ps::SArray>(data_len, 0); - ps_->InitGrad(key, grad_ptr); + GradPtr grad_ptr = std::make_shared<::ps::SArray>(data_len, 0); + ps_->InitGrad(key, grad_ptr); + } pos += data_len; } } @@ -379,22 +382,22 @@ void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &va MS_EXCEPTION_IF_NULL(cnode); if (optim_name == kSparseAdam) { std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_); + std::make_shared(rank_id_, pserver_num_, worker_num_); optimizer->InitKernel(cnode, optim_inputs_shape_[key]); optimizers_[key] = optimizer; } else if (optim_name == kSparseLazyAdam) { std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_); + std::make_shared(rank_id_, pserver_num_, worker_num_); optimizer->InitKernel(cnode, optim_inputs_shape_[key]); optimizers_[key] = optimizer; } else if (optim_name == kApplyMomentum) { std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_); + std::make_shared(rank_id_, pserver_num_, worker_num_); optimizer->InitKernel(cnode, optim_inputs_shape_[key]); optimizers_[key] = optimizer; } else if (optim_name == kSparseFtrl) { std::shared_ptr optimizer = - std::make_shared(rank_id_, pserver_num_); + std::make_shared(rank_id_, pserver_num_, worker_num_); optimizer->InitKernel(cnode, optim_inputs_shape_[key]); optimizers_[key] = optimizer; } @@ -416,8 +419,8 @@ const CNodePtr ParameterServer::GetCNode(const std::string &name) const { template void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) { - MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_; if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) { + MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_; weights_[key] = weight; tokens_[key] = 0; is_embedding_[key] = false; @@ -435,29 +438,37 @@ void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) { template void ParameterServer::InitEmbeddingTable( const Key &key, const std::shared_ptr>>> &shapes) { - std::shared_ptr lookup = std::make_shared(rank_id_, pserver_num_); - lookup->InitKernel(shapes); - embedding_lookup_ops_[key] = lookup; - - // Init embedding weight - const std::vector &input_shapes = lookup->input_sizes(); - size_t total_dims = 1; - for (auto shape : input_shapes) { - total_dims *= shape; - } + if (weights_.count(key) == 0) { + std::shared_ptr lookup = + std::make_shared(rank_id_, pserver_num_, worker_num_); + lookup->InitKernel(shapes); + embedding_lookup_ops_[key] = lookup; + + // Init embedding weight + const std::vector &input_shapes = lookup->input_sizes(); + size_t total_dims = 1; + for (auto shape : input_shapes) { + total_dims *= shape; + } + + WeightPtr embedding = std::make_shared(total_dims, 0); + T *embedding_data = 
embedding->data(); + std::default_random_engine engine; + std::normal_distribution random(0, 0.01); + for (size_t i = 0; i < total_dims; i++) { + embedding_data[i] = random(engine); + } + weights_[key] = embedding; + tokens_[key] = 0; + is_embedding_[key] = true; - WeightPtr embedding = std::make_shared(total_dims, 0); - T *embedding_data = embedding->data(); - std::default_random_engine engine; - std::normal_distribution random(0, 0.01); - for (size_t i = 0; i < total_dims; i++) { - embedding_data[i] = random(engine); + grads_accum_counter_[key] = 0; } - weights_[key] = embedding; - tokens_[key] = 0; - is_embedding_[key] = true; +} - grads_accum_counter_[key] = 0; +template +bool ParameterServer::HasWeight(const Key &key) { + return (weights_.count(key) > 0 && !is_embedding_.count(key)); } template
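
Notes on the fix:

The workspace changes in the sparse optimizer kernels account for the fact that the parameter server buffers the sparse gradient segments of every worker in one iteration before the optimizer kernel runs, so workspaces sized for a single worker's indices become too small as soon as more than one worker pushes. Below is a minimal sketch of the sizing rule, mirroring the SparseApplyAdam hunk above; the helper name and the exact purpose of each buffer are illustrative, not the actual MindSpore code.

#include <cstddef>
#include <vector>

std::vector<size_t> AdamWorkspaceSizes(size_t indices_size, size_t var_outer_dim_size,
                                       size_t var_first_dim_size, size_t worker_num) {
  std::vector<size_t> sizes;
  // Every buffer that holds per-row gradient values or row indices must be able
  // to hold the contributions of all workers at once.
  sizes.push_back(indices_size * var_outer_dim_size * sizeof(float) * worker_num);        // pushed gradient rows
  sizes.push_back(indices_size * sizeof(int) * worker_num);                               // pushed row indices
  sizes.push_back(indices_size * var_outer_dim_size * sizeof(float) * worker_num);        // scratch rows for the reduce step
  sizes.push_back(indices_size * sizeof(int) * worker_num);                               // scratch indices for the reduce step
  sizes.push_back(var_first_dim_size * var_outer_dim_size * sizeof(float) * worker_num);  // extra scratch (Adam only)
  return sizes;
}

With worker_num equal to 1 these sizes reduce to the old ones, so single-worker jobs are unaffected.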
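The SparseOptimInfo::Accumulate change is a unit fix: grads_offset_ and indices_offset_ index elements of float and int buffers, while incr_grad_size and incr_indice_size are byte counts passed to memcpy_s, so advancing the offsets by the byte counts walked past the end of the accumulation buffers once a second worker's gradient arrived. A small sketch of the intended bookkeeping follows; the type is hypothetical and std::memcpy stands in for memcpy_s.

#include <cstddef>
#include <cstring>
#include <vector>

// The write offset is an element index into a float buffer, so it advances by
// the element count; only the copy itself works in bytes.
struct SparseGradAccumulator {
  std::vector<float> grad_data;  // sized for worker_num * indices_size rows
  size_t grads_offset = 0;       // element index of the next free slot

  bool Append(const float *src, size_t length /* number of floats */) {
    if (grads_offset + length > grad_data.size()) {
      return false;  // the next worker's segment would not fit
    }
    std::memcpy(grad_data.data() + grads_offset, src, length * sizeof(float));
    grads_offset += length;  // advance by elements, not by length * sizeof(float)
    return true;
  }
};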
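SparseAdamOptimInfo::Update previously copied the entire flattened values blob into the beta1_power input alone. The rewritten version walks the lens array and copies each scalar hyperparameter (beta1_power, beta2_power, lr, beta1, beta2, epsilon) into its own input at the matching offset. A self-contained sketch of that unpacking pattern follows, with hypothetical names and std::memcpy standing in for memcpy_s.

#include <cstring>
#include <iostream>
#include <vector>

// Copies lens[i] floats from the flat buffer into dsts[i], advancing the read
// offset field by field, so each destination only receives its own data.
bool UnpackByLengths(const std::vector<float> &values, const std::vector<int> &lens,
                     const std::vector<std::vector<float> *> &dsts) {
  size_t offset = 0;
  for (size_t i = 0; i < dsts.size(); ++i) {
    const size_t count = static_cast<size_t>(lens[i]);
    if (offset + count > values.size() || dsts[i]->size() < count) {
      return false;  // would overrun the source or the destination buffer
    }
    std::memcpy(dsts[i]->data(), values.data() + offset, count * sizeof(float));
    offset += count;  // the next field starts where this one ended
  }
  return true;
}

int main() {
  // beta1_power, beta2_power, lr, beta1, beta2, epsilon flattened into one message.
  std::vector<float> values = {0.9f, 0.999f, 0.001f, 0.9f, 0.999f, 1e-8f};
  std::vector<int> lens = {1, 1, 1, 1, 1, 1};
  std::vector<float> b1p(1), b2p(1), lr(1), b1(1), b2(1), eps(1);
  std::vector<std::vector<float> *> dsts = {&b1p, &b2p, &lr, &b1, &b2, &eps};
  if (UnpackByLengths(values, lens, dsts)) {
    std::cout << "lr = " << lr[0] << std::endl;  // prints lr = 0.001
  }
  return 0;
}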
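Finally, weight and embedding-table initialization is made idempotent: HandleInitWeights now checks HasWeight(key) and InitEmbeddingTable checks weights_.count(key) before allocating, so the repeated init messages that arrive when several workers initialize the same key no longer reallocate buffers the server is already using. A simplified sketch of that guard follows; the types are hypothetical, and the real HasWeight additionally consults the is_embedding_ flag so embedding tables keep their own initialization path.

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <vector>

using Key = uint64_t;
using WeightPtr = std::shared_ptr<std::vector<float>>;

// A second init request for the same key is ignored, so buffers that other
// requests may already reference are never reallocated.
class TinyParameterServer {
 public:
  bool HasWeight(const Key &key) const { return weights_.count(key) > 0; }

  void InitWeight(const Key &key, const WeightPtr &weight) {
    if (HasWeight(key)) {
      return;  // already initialized by an earlier worker's message
    }
    weights_[key] = weight;
    grads_accum_counter_[key] = 0;
  }

 private:
  std::map<Key, WeightPtr> weights_;
  std::map<Key, size_t> grads_accum_counter_;
};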