Merge pull request !4704 from chengang/fix_ps_multi_worker
Tag: v0.7.0-beta
@@ -26,7 +26,8 @@ namespace kernel {
 namespace ps {
 class ApplyMomentumPSKernel : public ApplyMomentumCPUKernel, public PServerKernel {
  public:
-  ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  ApplyMomentumPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~ApplyMomentumPSKernel() override = default;
   bool Execute(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
@@ -26,7 +26,8 @@ namespace kernel {
 namespace ps {
 class EmbeddingLookUpPSKernel : public EmbeddingLookUpCPUKernel, public PServerKernel {
  public:
-  EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  EmbeddingLookUpPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~EmbeddingLookUpPSKernel() override = default;
   void InitKernel(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &) override;
@@ -27,7 +27,8 @@ namespace ps {
 using mindspore::parallel::ps::Util;
 class PServerKernel {
  public:
-  PServerKernel(size_t rank_id, size_t pserver_num) : rank_id_(rank_id), pserver_num_(pserver_num) {}
+  PServerKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : rank_id_(rank_id), pserver_num_(pserver_num), worker_num_(worker_num) {}
   ~PServerKernel() = default;
   PServerKernel(const PServerKernel &) = delete;
   PServerKernel &operator=(const PServerKernel &) = delete;

@@ -50,6 +51,7 @@ class PServerKernel {
   size_t rank_id_;
   size_t pserver_num_;
+  size_t worker_num_;
 };
 } // namespace ps
 } // namespace kernel
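Note: the hunks above and below thread a new worker_num argument from the parameter server down into the base PServerKernel and every derived optimizer kernel. A minimal standalone sketch of that pattern (class names are illustrative, not the actual MindSpore headers):

#include <cstddef>

class PServerKernelSketch {
 public:
  PServerKernelSketch(size_t rank_id, size_t pserver_num, size_t worker_num)
      : rank_id_(rank_id), pserver_num_(pserver_num), worker_num_(worker_num) {}
  virtual ~PServerKernelSketch() = default;

 protected:
  size_t rank_id_;
  size_t pserver_num_;
  size_t worker_num_;  // new: how many workers push gradients to this server
};

class SparseOptimizerKernelSketch : public PServerKernelSketch {
 public:
  using PServerKernelSketch::PServerKernelSketch;  // forwards all three values
};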
@@ -58,11 +58,11 @@ void SparseApplyAdamPSKernel::InitKernel(
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov");
   }
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
+  workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
 }

 void SparseApplyAdamPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
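Note: these workspace buffers hold the gradient and index segments pushed by every worker before the sparse optimizer is applied, which is why each size is now multiplied by worker_num_. A small arithmetic sketch under assumed sizes (all numbers hypothetical):

#include <cstddef>
#include <cstdio>

int main() {
  // All values hypothetical, for illustration only.
  size_t indices_size = 1000;       // indices pushed per worker
  size_t var_outer_dim_size = 64;   // width of one embedding row
  size_t worker_num = 4;            // workers feeding this parameter server

  // One worker's worth of data per buffer is no longer enough; the server may
  // hold segments from every worker at once, hence the * worker_num factor.
  size_t grad_workspace = indices_size * var_outer_dim_size * sizeof(float) * worker_num;
  size_t index_workspace = indices_size * sizeof(int) * worker_num;
  std::printf("grad workspace: %zu bytes, index workspace: %zu bytes\n", grad_workspace, index_workspace);
  return 0;
}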
@@ -27,7 +27,8 @@ namespace ps {
 using mindspore::kernel::SparseApplyAdamCPUKernel;
 class SparseApplyAdamPSKernel : public SparseApplyAdamCPUKernel, public PServerKernel {
  public:
-  SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  SparseApplyAdamPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~SparseApplyAdamPSKernel() override = default;
   void InitKernel(const CNodePtr &cnode,
@@ -62,10 +62,10 @@ void SparseApplyFtrlPSKernel::InitKernel(
   if (lr_power_ > 0) {
     MS_LOG(EXCEPTION) << "lr_power should be a non-positive scalar";
   }
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
 }

 void SparseApplyFtrlPSKernel::ReInit(const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
@@ -27,7 +27,8 @@ namespace ps {
 using mindspore::kernel::SparseApplyFtrlCPUKernel;
 class SparseApplyFtrlPSKernel : public SparseApplyFtrlCPUKernel, public PServerKernel {
  public:
-  SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  SparseApplyFtrlPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~SparseApplyFtrlPSKernel() override = default;
   void InitKernel(const CNodePtr &cnode,
@@ -58,11 +58,10 @@ void SparseApplyLazyAdamPSKernel::InitKernel(
   if (AnfAlgo::HasNodeAttr(USE_NESTEROV, cnode)) {
     use_nesterov_ = AnfAlgo::GetNodeAttr<bool>(cnode, "use_nesterov");
   }
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float));
-  workspace_size_list_.emplace_back(indices_size_ * sizeof(int));
-  workspace_size_list_.emplace_back(var_first_dim_size_ * var_outer_dim_size_ * sizeof(float));
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * var_outer_dim_size_ * sizeof(float) * worker_num_);
+  workspace_size_list_.emplace_back(indices_size_ * sizeof(int) * worker_num_);
 }

 void SparseApplyLazyAdamPSKernel::ReInit(
@@ -27,7 +27,8 @@ namespace ps {
 using mindspore::kernel::SparseApplyLazyAdamCPUKernel;
 class SparseApplyLazyAdamPSKernel : public SparseApplyLazyAdamCPUKernel, public PServerKernel {
  public:
-  SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num) : PServerKernel(rank_id, pserver_num) {}
+  SparseApplyLazyAdamPSKernel(size_t rank_id, size_t pserver_num, size_t worker_num)
+      : PServerKernel(rank_id, pserver_num, worker_num) {}
   ~SparseApplyLazyAdamPSKernel() override = default;
   void InitKernel(const CNodePtr &cnode,
@@ -85,7 +85,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
-  grads_offset_ += incr_grad_size;
+  grads_offset_ += lengths[grad_index];
   gradient()->size += incr_grad_size;

   // Append indice data to the end
@@ -103,7 +103,7 @@ void SparseOptimInfo::Accumulate(const Values &values, const Lengths &lengths) {
   if (ret2 != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret2 << ")";
   }
-  indices_offset_ += incr_indice_size;
+  indices_offset_ += lengths[indices_index];
   indices()->size += incr_indice_size;
 }
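Note: grads_offset_ and indices_offset_ index elements, while incr_grad_size and incr_indice_size are byte counts used for memcpy_s; advancing by lengths[grad_index] / lengths[indices_index] keeps the offsets in element units. An illustrative standalone sketch of that append step (helper name and buffer type are hypothetical, not the real SparseOptimInfo API):

#include <cstddef>
#include <cstring>
#include <vector>

// Append one worker's gradient segment to an accumulation buffer, advancing an
// offset measured in elements rather than bytes.
void AppendGrad(std::vector<float> *accum, size_t *grads_offset,
                const float *worker_grad, size_t grad_len /* element count */) {
  size_t incr_grad_size = grad_len * sizeof(float);  // byte count, used only for the copy
  if (accum->size() < *grads_offset + grad_len) {
    accum->resize(*grads_offset + grad_len);
  }
  std::memcpy(accum->data() + *grads_offset, worker_grad, incr_grad_size);
  *grads_offset += grad_len;  // element count, not incr_grad_size (bytes)
}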
@@ -157,15 +157,58 @@ SparseAdamOptimInfo::SparseAdamOptimInfo(const AddressPtr &weight, const Address
   inputs_.push_back(epsilon);
   inputs_.push_back(grad);
   inputs_.push_back(indices);
-  grads_offset_ = 0;
-  indices_offset_ = 0;
+  grads_offset_ = grad->size / sizeof(float);
+  indices_offset_ = indices->size / sizeof(int);
 }

 void SparseAdamOptimInfo::Update(const Values &values, const Lengths &lens) {
-  void *data_ptr = values.data();
-  AddressPtr beta1_power = inputs_[3];
-  size_t size = values.size() * sizeof(float);
-  auto ret = memcpy_s(beta1_power->addr, size, data_ptr, size);
+  float *data_ptr = values.data();
+  int offset = 0;
+
+  AddressPtr &beta1_power = inputs_[3];
+  int size = lens[0];
+  int bytes = sizeof(float);
+  auto ret = memcpy_s(beta1_power->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &beta2_power = inputs_[4];
+  size = lens[1];
+  ret = memcpy_s(beta2_power->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &lr = inputs_[5];
+  size = lens[2];
+  ret = memcpy_s(lr->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &beta1 = inputs_[6];
+  size = lens[3];
+  ret = memcpy_s(beta1->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &beta2 = inputs_[7];
+  size = lens[4];
+  ret = memcpy_s(beta2->addr, size * bytes, data_ptr + offset, size * bytes);
+  if (ret != 0) {
+    MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
+  }
+  offset += size;
+
+  AddressPtr &epsilon = inputs_[8];
+  size = lens[5];
+  ret = memcpy_s(epsilon->addr, size * bytes, data_ptr + offset, size * bytes);
   if (ret != 0) {
     MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")";
   }
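Note: the rewritten Update copies each scalar hyperparameter (beta1_power, beta2_power, lr, beta1, beta2, epsilon) from its own slice of values, with lens[i] giving the element count of slice i, instead of copying the whole values buffer into beta1_power. A condensed sketch of the same offset walk (function and buffer names are hypothetical):

#include <cstddef>
#include <cstring>
#include <vector>

// Copy six hyperparameter slices out of one flat `values` buffer, where
// lens[i] is the element count of slice i and dests[i] is its destination.
void CopyHyperParams(const std::vector<float> &values, const std::vector<int> &lens,
                     const std::vector<float *> &dests) {
  std::size_t offset = 0;
  for (std::size_t i = 0; i < dests.size() && i < lens.size(); ++i) {
    std::size_t bytes = static_cast<std::size_t>(lens[i]) * sizeof(float);
    std::memcpy(dests[i], values.data() + offset, bytes);
    offset += static_cast<std::size_t>(lens[i]);  // advance by elements
  }
}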
@@ -188,8 +231,8 @@ SparseFtrlOptimInfo::SparseFtrlOptimInfo(const AddressPtr &weight, const Address
   inputs_.push_back(linear);
   inputs_.push_back(grad);
   inputs_.push_back(indices);
-  grads_offset_ = 0;
-  indices_offset_ = 0;
+  grads_offset_ = grad->size / sizeof(float);
+  indices_offset_ = indices->size / sizeof(int);
 }

 const AddressPtr &SparseFtrlOptimInfo::gradient() { return inputs_[3]; }
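Note: grad->size and indices->size are byte counts for the data already placed in the buffers, so dividing by sizeof(float) / sizeof(int) starts the append offsets at the existing element counts rather than at 0, letting later workers' pushes accumulate behind the first instead of overwriting it. A tiny worked example with assumed sizes:

#include <cstdio>

int main() {
  // Hypothetical byte sizes of the first worker's push.
  unsigned long grad_bytes = 4096;
  unsigned long indices_bytes = 256;

  // Offsets are element counts, so subsequent pushes append after the data
  // that is already in the buffers (cf. the constructor changes above).
  unsigned long grads_offset = grad_bytes / sizeof(float);     // 1024 elements
  unsigned long indices_offset = indices_bytes / sizeof(int);  // 64 elements
  std::printf("grads_offset=%lu indices_offset=%lu\n", grads_offset, indices_offset);
  return 0;
}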
@@ -114,6 +114,7 @@ class ParameterServer {
   void InitGrad(const Key &key, const GradPtr &grad);
   void InitEmbeddingTable(const Key &key,
                           const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes);
+  bool HasWeight(const Key &key);
   void Finalize();
   void UpdateWeights();
   void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
@@ -211,12 +212,14 @@ void ParameterServer<T>::ServerHandler::HandleInitWeights(const ::ps::KVMeta &re
       Key key = req_data.keys[i];
       size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i];
-      WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
-      weight_ptr->CopyFrom(data_ptr + pos, data_len);
-      ps_->InitWeight(key, weight_ptr);
-
-      GradPtr grad_ptr = std::make_shared<::ps::SArray<T>>(data_len, 0);
-      ps_->InitGrad(key, grad_ptr);
+      if (!ps_->HasWeight(key)) {
+        WeightPtr weight_ptr = std::make_shared<::ps::SArray<T>>();
+        weight_ptr->CopyFrom(data_ptr + pos, data_len);
+        ps_->InitWeight(key, weight_ptr);
+
+        GradPtr grad_ptr = std::make_shared<::ps::SArray<T>>(data_len, 0);
+        ps_->InitGrad(key, grad_ptr);
+      }

       pos += data_len;
     }
   }
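Note: with more than one worker, every worker sends an init-weights request for the same key; the HasWeight guard makes initialization idempotent so only the first request allocates the weight and its gradient accumulator. A simplified standalone sketch of that guard (the map types are stand-ins, not the real WeightPtr/GradPtr containers):

#include <cstddef>
#include <unordered_map>
#include <vector>

using Key = unsigned long long;

struct ServerSketch {
  std::unordered_map<Key, std::vector<float>> weights;
  std::unordered_map<Key, std::vector<float>> grads;

  bool HasWeight(Key key) const { return weights.count(key) > 0; }

  void HandleInitWeight(Key key, const float *data, std::size_t len) {
    if (!HasWeight(key)) {                    // only the first worker's request allocates
      weights[key].assign(data, data + len);  // copy the initial weight once
      grads[key].assign(len, 0.0f);           // zero-initialized gradient accumulator
    }
    // subsequent requests for the same key fall through without reallocating
  }
};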
@@ -379,22 +382,22 @@ void ParameterServer<T>::InitOptimInputsShape(const Keys &keys, const Values &va
   MS_EXCEPTION_IF_NULL(cnode);
   if (optim_name == kSparseAdam) {
     std::shared_ptr<PServerKernel> optimizer =
-      std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_);
+      std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
     optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
     optimizers_[key] = optimizer;
   } else if (optim_name == kSparseLazyAdam) {
     std::shared_ptr<PServerKernel> optimizer =
-      std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_);
+      std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
     optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
     optimizers_[key] = optimizer;
   } else if (optim_name == kApplyMomentum) {
     std::shared_ptr<PServerKernel> optimizer =
-      std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_);
+      std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_, worker_num_);
     optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
     optimizers_[key] = optimizer;
   } else if (optim_name == kSparseFtrl) {
     std::shared_ptr<PServerKernel> optimizer =
-      std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_);
+      std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_, worker_num_);
     optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
     optimizers_[key] = optimizer;
   }
@@ -416,8 +419,8 @@ const CNodePtr ParameterServer<T>::GetCNode(const std::string &name) const {
 template <typename T>
 void ParameterServer<T>::InitWeight(const Key &key, const WeightPtr &weight) {
-  MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
   if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
+    MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
     weights_[key] = weight;
     tokens_[key] = 0;
     is_embedding_[key] = false;
@@ -435,29 +438,37 @@ void ParameterServer<T>::InitGrad(const Key &key, const GradPtr &grad) {
 template <typename T>
 void ParameterServer<T>::InitEmbeddingTable(
   const Key &key, const std::shared_ptr<std::vector<std::shared_ptr<std::vector<size_t>>>> &shapes) {
-  std::shared_ptr<PServerKernel> lookup = std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_);
-  lookup->InitKernel(shapes);
-  embedding_lookup_ops_[key] = lookup;
-
-  // Init embedding weight
-  const std::vector<size_t> &input_shapes = lookup->input_sizes();
-  size_t total_dims = 1;
-  for (auto shape : input_shapes) {
-    total_dims *= shape;
-  }
-
-  WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
-  T *embedding_data = embedding->data();
-  std::default_random_engine engine;
-  std::normal_distribution<float> random(0, 0.01);
-  for (size_t i = 0; i < total_dims; i++) {
-    embedding_data[i] = random(engine);
-  }
-  weights_[key] = embedding;
-  tokens_[key] = 0;
-  is_embedding_[key] = true;
+  if (weights_.count(key) == 0) {
+    std::shared_ptr<PServerKernel> lookup =
+      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_, worker_num_);
+    lookup->InitKernel(shapes);
+    embedding_lookup_ops_[key] = lookup;
+
+    // Init embedding weight
+    const std::vector<size_t> &input_shapes = lookup->input_sizes();
+    size_t total_dims = 1;
+    for (auto shape : input_shapes) {
+      total_dims *= shape;
+    }
+
+    WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
+    T *embedding_data = embedding->data();
+    std::default_random_engine engine;
+    std::normal_distribution<float> random(0, 0.01);
+    for (size_t i = 0; i < total_dims; i++) {
+      embedding_data[i] = random(engine);
+    }
+    weights_[key] = embedding;
+    tokens_[key] = 0;
+    is_embedding_[key] = true;
+  }

   grads_accum_counter_[key] = 0;
 }

+template <typename T>
+bool ParameterServer<T>::HasWeight(const Key &key) {
+  return (weights_.count(key) > 0 && !is_embedding_.count(key));
+}
+
 template <typename T>