|
|
|
@@ -33,8 +33,7 @@ void ParameterServer::Run(const FuncGraphPtr &func_graph) {
   }
   Init(func_graph);
   server_node_->Start();
-  rank_id_ = server_node_->rank_id();
-  PSContext::instance()->SetPSRankId(rank_id_);
+  PSContext::instance()->SetPSRankId(server_node_->rank_id());
   thread_->join();
   SyncEmbeddingTables();
   MS_LOG(INFO) << "PServer finished updating models, starts finalizing...";
@@ -118,22 +117,22 @@ void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &value
       MS_EXCEPTION_IF_NULL(cnode);
       if (optim_name == kSparseAdam) {
         std::shared_ptr<PServerKernel> optimizer =
-          std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
+          std::make_shared<kernel::ps::SparseApplyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
         optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       } else if (optim_name == kSparseLazyAdam) {
         std::shared_ptr<PServerKernel> optimizer =
-          std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(rank_id_, pserver_num_, worker_num_);
+          std::make_shared<kernel::ps::SparseApplyLazyAdamPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
         optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       } else if (optim_name == kApplyMomentum) {
         std::shared_ptr<PServerKernel> optimizer =
-          std::make_shared<kernel::ps::ApplyMomentumPSKernel>(rank_id_, pserver_num_, worker_num_);
+          std::make_shared<kernel::ps::ApplyMomentumPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
         optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       } else if (optim_name == kSparseFtrl) {
         std::shared_ptr<PServerKernel> optimizer =
-          std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(rank_id_, pserver_num_, worker_num_);
+          std::make_shared<kernel::ps::SparseApplyFtrlPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
         optimizer->InitKernel(cnode, optim_inputs_shape_[key]);
         optimizers_[key] = optimizer;
       }
@@ -144,7 +143,7 @@ void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &value
 void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) {
   MS_EXCEPTION_IF_NULL(weight);
   if ((weights_.count(key) == 0) || (is_embedding_[key] && weights_.count(key) != 0)) {
-    MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << rank_id_;
+    MS_LOG(INFO) << "Initializing weight for key " << key << ", server rank " << server_node_->rank_id();
     weights_[key] = weight;
     tokens_[key] = 0;
     is_embedding_[key] = false;
@@ -165,7 +164,7 @@ void ParameterServer::InitEmbeddingTable(
   MS_EXCEPTION_IF_NULL(shapes);
   if (weights_.count(key) == 0) {
     std::shared_ptr<PServerKernel> lookup =
-      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(rank_id_, pserver_num_, worker_num_);
+      std::make_shared<kernel::ps::EmbeddingLookUpPSKernel>(server_node_->rank_id(), pserver_num_, worker_num_);
     lookup->InitKernel(shapes);
     embedding_lookup_ops_[key] = lookup;
 
@@ -244,7 +243,7 @@ void ParameterServer::UpdateWeights() {
                          [](std::shared_ptr<std::vector<size_t>> input_shapes) -> std::vector<size_t> { return *input_shapes; });
         }
         optimizer->ReInit(shapes);
-        optim_info->ComputeMean(shapes, worker_num_, pserver_num_, rank_id_);
+        optim_info->ComputeMean(shapes, worker_num_, pserver_num_, server_node_->rank_id());
         optimizer->Execute(inputs, workspaces, outputs);
         optim_info->Reset();
       }
@@ -296,7 +295,6 @@ WeightPtr ParameterServer::weight(const Key &key) {
     MS_LOG(EXCEPTION) << "Invalid weight key " << key;
   }
   WeightPtr weight_ptr = weights_[key];
-  MS_LOG(DEBUG) << "The weight ptr size is:" << weight_ptr->size();
   MS_EXCEPTION_IF_NULL(weight_ptr);
   WeightPtr copy_weight_ptr = std::make_shared<std::vector<float>>(weight_ptr->size(), 0);
   MS_EXCEPTION_IF_NULL(copy_weight_ptr);