Merge pull request !4704 from chengang/fix_ps_multi_worker
@@ -26,7 +26,8 @@ namespace kernel {
 namespace ps {
 class ApplyMomentumPSKernel : public ApplyMomentumCPUKernel, public PServerKernel {
  public:
-  ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~ApplyMomentumPSKernel() override = default;
   bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@@ -26,7 +26,8 @@ namespace kernel {
 namespace ps {
 class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerKernel {
  public:
-  EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~EmbeddingLookUpPSKernel() override = default;
   void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
@@ -27,7 +27,8 @@ namespace ps {
 using mindspore::parallel::ps::Util;
 class PServerKernel {
  public:
-  PServerKernel(size_t rank_id, size_t pserver_num) : rank_id_(rank_id), pserver_num_(pserver_num) {}
+  PServerKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : rank_id_(rank_id), pserver_num_(pserver_num), worker_num_(worker_num) {}
   ~PServerKernel() = default;
   PServerKernel(const PServerKernel &) = delete;
   PServerKernel &operator=(const PServerKernel &) = delete;
@@ -50,6 +51,7 @@ class PServerKernel {
   size_t rank_id_;
   size_t pserver_num_;
+  size_t worker_num_;
 };
 }  // namespace ps
 }  // namespace kernel
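Note: with worker_num_ stored on the base class, every parameter-server kernel can size its buffers per worker. A minimal sketch of a derived kernel using the new field (the class and helper below are hypothetical and assume worker_num_ is accessible to derived classes; they are not part of this patch):

    // Hypothetical derived kernel: shows the new three-argument construction
    // and a per-worker buffer size derived from worker_num_.
    class ExamplePSKernel : public PServerKernel {
     public:
      ExamplePSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
          : PServerKernel(rank_id, pserver_num, worker_num) {}
      size_t PerWorkerBytes(size_t elements) const {
        return elements * sizeof(float) * worker_num_;  // one float slice per worker
      }
    };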
@@ -58,11 +58,11 @@ void SparseApplyAdamPSKernel::InitKernel(
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov");
   }
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
+  workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
 }

 void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
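Note: the same worker_num_ scaling recurs in the SparseApplyFtrl and SparseApplyLazyAdam kernels below. The workspaces apparently need room for the sparse gradient and index slices pushed by every worker before they are merged on the server, so each entry is multiplied by the worker count. A hedged sketch of the sizing rule (the helper name is illustrative; the other identifiers come from the hunk above):

    // Illustrative helper: bytes needed to hold one slice from each worker.
    inline size_t ScaledWorkspaceBytes(size_t elements, size_t elem_bytes, size_t worker_num) {
      return elements * elem_bytes * worker_num;
    }
    // e.g. workspace_size_list_.emplace_back(
    //        ScaledWorkspaceBytes(indices_size_ * var_outer_dim_size_, sizeof(float), worker_num_));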
@@ -27,7 +27,8 @@ namespace ps {
 using mindspore::kernel::SparseApplyAdamCPUKernel;
 class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerKernel {
  public:
-  SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~SparseApplyAdamPSKernel() override = default;
   void InitKernel(const CNodePtr &cnode,
@@ -62,10 +62,10 @@ void SparseApplyFtrlPSKernel::InitKernel(
   if (lr_power_ > 0) {
     MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
   }
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
 }

 void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
@@ -27,7 +27,8 @@ namespace ps {
 using mindspore::kernel::SparseApplyFtrlCPUKernel;
 class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerKernel {
  public:
-  SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~SparseApplyFtrlPSKernel() override = default;
   void InitKernel(const CNodePtr &cnode,
@@ -58,11 +58,10 @@ void SparseApplyLazyAdamPSKernel::InitKernel(
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov");
   }
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
 }

 void SparseApplyLazyAdamPSKernel::ReInit(
@@ -27,7 +27,8 @@ namespace ps {
 using mindspore::kernel::SparseApplyLazyAdamCPUKernel;
 class SparseApplyLazyAdamPSKernel : public SparseApplyLazyAdamCPUKernel, public PServerKernel {
  public:
-  SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~SparseApplyLazyAdamPSKernel() override = default;
   void InitKernel(const CNodePtr &cnode,
@@ -85,7 +85,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
-  grads_offset_ += incr_grad_size;
+  grads_offset_ += lengths[grad_index];
   gradient()->size += incr_grad_size;

   // Append indice data to the end
@@ -103,7 +103,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
   if (ret2 != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
   }
-  indices_offset_ += incr_indice_size;
+  indices_offset_ += lengths[indices_index];
   indices()->size += incr_indice_size;
 }
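Note: grads_offset_ and indices_offset_ index elements in the accumulation buffers (the constructors below seed them with size / sizeof(element)), while incr_grad_size and incr_indice_size are the byte counts passed to memcpy_s. Advancing the cursors by lengths[grad_index] and lengths[indices_index], i.e. element counts, keeps later appends aligned. A small worked example under that assumption (numbers are illustrative only):

    // Assuming incr_grad_size = lengths[grad_index] * sizeof(float):
    size_t pushed_elements = 10;                              // lengths[grad_index]
    size_t incr_grad_size = pushed_elements * sizeof(float);  // 40 bytes copied by memcpy_s
    // grads_offset_ is a float-element cursor: advancing it by 40 would skip 30 slots;
    // advancing it by 10 lands the next worker's push immediately after this one.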
@@ -157,15 +157,58 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
   inputs_.push_back(epsilon);
   inputs_.push_back(grad);
   inputs_.push_back(indices);
-  grads_offset_ = 0;
-  indices_offset_ = 0;
+  grads_offset_ = grad->size / sizeof(float);
+  indices_offset_ = indices->size / sizeof(int);
 }

 void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
-  void *data_ptr = values.data();
-  AddressPtr beta1_power = inputs_[3];
-  size_t size = values.size() * sizeof(float);
-  auto ret = memcpy_s(beta1_power->addr, size, data_ptr, size);
+  float *data_ptr = values.data();
+  int offset = 0;
+  AddressPtr &beta1_power = inputs_[3];
+  int size = lens[0];
+  int bytes = sizeof(float);
+  auto ret = memcpy_s(beta1_power->addr, size * bytes, data_ptr + offset, size * bytes);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
+  offset += size;
+
+  AddressPtr &beta2_power = inputs_[4];
+  size = lens[1];
+  ret = memcpy_s(beta2_power->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &lr = inputs_[5];
+  size = lens[2];
+  ret = memcpy_s(lr->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &beta1 = inputs_[6];
+  size = lens[3];
+  ret = memcpy_s(beta1->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &beta2 = inputs_[7];
+  size = lens[4];
+  ret = memcpy_s(beta2->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &epsilon = inputs_[8];
+  size = lens[5];
+  ret = memcpy_s(epsilon->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
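Note: Update now unpacks one hyper-parameter per entry of lens instead of copying the whole values buffer into beta1_power. A loop-based sketch of the same unpacking (an illustrative refactor only, assuming the inputs_ layout shown above with slots 3..8 holding beta1_power, beta2_power, lr, beta1, beta2 and epsilon):

    // Illustrative equivalent of the unrolled copies above.
    float *data_ptr = values.data();
    int offset = 0;
    for (size_t i = 0; i < 6; ++i) {
      const AddressPtr &dst = inputs_[3 + i];
      size_t copy_bytes = lens[i] * sizeof(float);
      auto ret = memcpy_s(dst->addr, copy_bytes, data_ptr + offset, copy_bytes);
      if (ret != 0) {
        MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
      }
      offset += lens[i];
    }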
@@ -188,8 +231,8 @@ SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const Address
   inputs_.push_back(linear);
   inputs_.push_back(grad);
   inputs_.push_back(indices);
-  grads_offset_ = 0;
-  indices_offset_ = 0;
+  grads_offset_ = grad->size / sizeof(float);
+  indices_offset_ = indices->size / sizeof(int);
 }

 const AddressPtr &SparseFtrlOptimInfo::gradient() { return inputs_[3]; }
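Note: both sparse optimizer-info constructors now seed grads_offset_ and indices_offset_ from the element counts already held in grad and indices (byte size divided by element size), presumably so that the first Accumulate call appends after the data those buffers already contain. Illustrative numbers only:

    // grad->size is a byte count; the cursor is an element index.
    size_t grad_bytes = 400;                           // grad->size
    size_t grads_offset = grad_bytes / sizeof(float);  // 100: the next pushed slice starts at element 100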
@@ -114,6 +114,7 @@ class ParameterServer {
   void InitGrad(const Key &key, const GradPtr &grad);
   void InitEmbeddingTable(const Key &key,
                           const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
+  bool HasWeight(const Key &key);
   void Finalize();
   void UpdateWeights();
   void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
@@ -211,12 +212,14 @@ void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &re
     Key key = req_data.keys[i];
     size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];

-    WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
-    weight_ptr->CopyFrom(data_ptr + pos, data_len);
-    ps_->InitWeight(key, weight_ptr);
-
-    GradPtr grad_ptr = std::make_shared<::ps::SArray<T>>(data_len, 0);
-    ps_->InitGrad(key, grad_ptr);
+    if (!ps_->HasWeight(key)) {
+      WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
+      weight_ptr->CopyFrom(data_ptr + pos, data_len);
+      ps_->InitWeight(key, weight_ptr);
+
+      GradPtr grad_ptr = std::make_shared<::ps::SArray<T>>(data_len, 0);
+      ps_->InitGrad(key, grad_ptr);
+    }
     pos += data_len;
   }
 }
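Note: with more than one worker, every worker sends an init-weights request for the same keys. The HasWeight guard lets the first request allocate the weight and gradient buffers and turns later requests into no-ops, while pos is still advanced so the packed payload is walked correctly. A standalone sketch of the same idempotent-init pattern (the types and names here are hypothetical, not the ParameterServer API):

    #include <cstdint>
    #include <map>
    #include <vector>

    std::map<uint64_t, std::vector<float>> g_weights;

    // Initialize a key at most once, no matter how many workers push it.
    void InitOnce(uint64_t key, const float *data, size_t len) {
      if (g_weights.count(key) > 0) {
        return;  // already initialized by another worker's request
      }
      g_weights[key].assign(data, data + len);
    }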
@@ -379,22 +382,22 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
   MS_EXCEPTION_IF_NULL(cnode);
   if (optim_name == kSparseAdam) {
     std::shared_ptr<PServerKernel> optimizer =
-      std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_);
+      std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
     optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
     optimizers_[key] = optimizer;
   } else if (optim_name == kSparseLazyAdam) {
     std::shared_ptr<PServerKernel> optimizer =
-      std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_);
+      std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
     optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
     optimizers_[key] = optimizer;
   } else if (optim_name == kApplyMomentum) {
     std::shared_ptr<PServerKernel> optimizer =
-      std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_);
+      std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_, worker_num_);
     optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
     optimizers_[key] = optimizer;
   } else if (optim_name == kSparseFtrl) {
     std::shared_ptr<PServerKernel> optimizer =
-      std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_);
+      std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_, worker_num_);
     optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
     optimizers_[key] = optimizer;
   }
@@ -416,8 +419,8 @@ const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
 template <typename T>
 void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
-  MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
   if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
+    MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
     weights_[key] = weight;
     tokens_[key] = 0;
     is_embedding_[key] = false;
@@ -435,29 +438,37 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
 template <typename T>
 void ParameterServer<T>::InitEmbeddingTable(
   const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
-  std::shared_ptr<PServerKernel> lookup = std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_);
-  lookup->InitKernel(shapes);
-  embedding_lookup_ops_[key] = lookup;
-
-  // Init embedding weight
-  const std::vector<size_t> &input_shapes = lookup->input_sizes();
-  size_t total_dims = 1;
-  for (auto shape : input_shapes) {
-    total_dims *= shape;
-  }
-
-  WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
-  T *embedding_data = embedding->data();
-  std::default_random_engine engine;
-  std::normal_distribution<float> random(0, 0.01);
-  for (size_t i = 0; i < total_dims; i++) {
-    embedding_data[i] = random(engine);
-  }
-  weights_[key] = embedding;
-  tokens_[key] = 0;
-  is_embedding_[key] = true;
-
-  grads_accum_counter_[key] = 0;
+  if (weights_.count(key) == 0) {
+    std::shared_ptr<PServerKernel> lookup =
+      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_, worker_num_);
+    lookup->InitKernel(shapes);
+    embedding_lookup_ops_[key] = lookup;
+
+    // Init embedding weight
+    const std::vector<size_t> &input_shapes = lookup->input_sizes();
+    size_t total_dims = 1;
+    for (auto shape : input_shapes) {
+      total_dims *= shape;
+    }
+
+    WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
+    T *embedding_data = embedding->data();
+    std::default_random_engine engine;
+    std::normal_distribution<float> random(0, 0.01);
+    for (size_t i = 0; i < total_dims; i++) {
+      embedding_data[i] = random(engine);
+    }
+    weights_[key] = embedding;
+    tokens_[key] = 0;
+    is_embedding_[key] = true;
+
+    grads_accum_counter_[key] = 0;
+  }
+}
+
+template <typename T>
+bool ParameterServer<T>::HasWeight(const Key &key) {
+  return (weights_.count(key) > 0 && !is_embedding_.count(key));
 }

 template <typename T>