/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ #define MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_ #include #include #include #include #include #include #include #include #include #include #include #include "ir/func_graph.h" #include "session/session_basic.h" #include "session/kernel_graph.h" #include "session/anf_runtime_algorithm.h" #include "session/session_factory.h" #include "parallel/ps/common.h" #include "parallel/ps/optimizer_info.h" #include "parallel/ps/optimizer_info_builder.h" #include "parallel/ps/util.h" #include "device/cpu/kernel_select_cpu.h" #include "utils/context/ms_context.h" #include "kernel/kernel.h" #include "kernel/ps/pserver_kernel.h" #include "kernel/cpu/cpu_kernel_factory.h" #include "kernel/ps/sparse_apply_adam_ps_kernel.h" #include "kernel/ps/sparse_apply_ftrl_ps_kernel.h" #include "kernel/ps/apply_momentum_ps_kernel.h" #include "kernel/ps/embedding_look_up_ps_kernel.h" namespace mindspore { namespace parallel { namespace ps { using mindspore::kernel::ps::PServerKernel; template class ParameterServer { public: static ParameterServer &GetInstance() { static ParameterServer instance; return instance; } void Run(const FuncGraphPtr &func_graph); private: ParameterServer() : pserver_num_(0), worker_num_(0), rank_id_(0), grad_accum_count_(0), ps_(new ::ps::KVServer(0)), handler_(nullptr), func_graph_(nullptr), kernel_graph_(nullptr), sess_(nullptr), thread_(nullptr) {} ~ParameterServer() = default; ParameterServer(const ParameterServer &) = delete; ParameterServer &operator=(const ParameterServer &) = delete; struct ServerHandler { explicit ServerHandler(ParameterServer *ps) : ps_(ps) {} void operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server); void HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data); void HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); void HandleInitWeights(const ::ps::KVPairs &req_data); void HandleInitWeightToOptimId(const ::ps::KVPairs &req_data); void HandleInitInputsShape(const ::ps::KVPairs &req_data); void HandleInitEmbeddings(const ::ps::KVPairs &req_data); void HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res); ParameterServer *ps_; }; bool Init(const FuncGraphPtr &func_graph); void InitOptimInfoBuilders(); void InitWeightKeyToOptims(const Key &key, const int &optim_id); void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths); void InitWeight(const Key &key, const WeightPtr &weight); void InitGrad(const Key &key, const GradPtr &grad); void InitEmbeddingTable(const Key &key, const std::shared_ptr>>> &shapes); void UpdateWeights(); void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths); WeightPtr weight(const Key &key); void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res); int SumOfShapes(const std::vector &shapes) const; size_t PreComputeCapacity(const Keys &keys, const Lengths &lens); bool ReadyForUpdateWeights(); bool ReadyForAccumGrads(); void ResetGradAccumCount(); size_t pserver_num_; size_t worker_num_; size_t rank_id_; size_t grad_accum_count_; std::unique_ptr<::ps::KVServer> ps_; std::unique_ptr handler_; FuncGraphPtr func_graph_; std::shared_ptr kernel_graph_; std::shared_ptr sess_; std::unordered_map> optimizers_; std::unordered_map optim_inputs_shape_; std::unordered_map> optim_infos_; std::unordered_map> optim_info_builders_; std::unordered_map weight_key_to_optims_; std::unordered_map weights_; std::unordered_map grads_; std::unordered_map grads_accum_counter_; // std::unordered_map embeddings_; std::unordered_map> embedding_lookup_ops_; std::unordered_map embedding_row_lens_; T learning_rate_; T momentum_; std::mutex mutex_; std::condition_variable apply_grads_cv_; std::condition_variable accum_grads_cv_; std::unique_ptr thread_; friend struct ServerHandler; }; class FuncGraph; template void ParameterServer::ServerHandler::operator()(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVServer *server) { ::ps::KVPairs res; if (req_meta.cmd == kInitWeightsCmd) { MS_LOG(ERROR) << "handle init weights cmd" << std::endl; HandleInitWeights(req_data); } else if (req_meta.cmd == kInitWeightToOptimIdCmd) { MS_LOG(ERROR) << "handle init weight optim id mapping cmd" << std::endl; HandleInitWeightToOptimId(req_data); } else if (req_meta.cmd == kInitOptimInputsShapeCmd) { MS_LOG(ERROR) << "handle init inputs shape cmd" << std::endl; HandleInitInputsShape(req_data); } else if (req_meta.cmd == kInitEmbeddingsCmd) { MS_LOG(ERROR) << "handle init embedding cmd" << std::endl; HandleInitEmbeddings(req_data); } else if (req_meta.cmd == kEmbeddingLookupCmd) { MS_LOG(ERROR) << "handle embedding lookup cmd" << std::endl; HandleEmbeddingLookup(req_meta, req_data, &res); } else if (req_meta.push) { MS_LOG(ERROR) << "handle push req cmd" << std::endl; HandlePushReq(req_meta, req_data); } else { MS_LOG(ERROR) << "handle pull req cmd" << std::endl; HandlePullReq(req_meta, req_data, &res); } server->Response(req_meta, res); } template void ParameterServer::ServerHandler::HandlePushReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data) { ps_->AccumGrad(req_data.keys, req_data.vals, req_data.lens); } template void ParameterServer::ServerHandler::HandlePullReq(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { res->keys = req_data.keys; ::ps::Key key = req_data.keys[0]; res->vals = *(ps_->weight(key)); } template void ParameterServer::ServerHandler::HandleInitWeights(const ::ps::KVPairs &req_data) { size_t key_num = req_data.keys.size(); T *data_ptr = req_data.vals.data(); size_t pos = 0; for (size_t i = 0; i < key_num; i++) { Key key = req_data.keys[i]; size_t data_len = req_data.lens.size() != key_num ? req_data.vals.size() / key_num : req_data.lens[i]; WeightPtr weight_ptr = std::make_shared<::ps::SArray>(); weight_ptr->CopyFrom(data_ptr + pos, data_len); ps_->InitWeight(key, weight_ptr); GradPtr grad_ptr = std::make_shared<::ps::SArray>(data_len, 0); ps_->InitGrad(key, grad_ptr); pos += data_len; } } template void ParameterServer::ServerHandler::HandleInitWeightToOptimId(const ::ps::KVPairs &req_data) { size_t key_num = req_data.keys.size(); for (size_t i = 0; i < key_num; i++) { Key key = req_data.keys[i]; T val = req_data.vals[i]; ps_->InitWeightKeyToOptims(key, val); } } template void ParameterServer::ServerHandler::HandleInitInputsShape(const ::ps::KVPairs &req_data) { ps_->InitOptimInputsShape(req_data.keys, req_data.vals, req_data.lens); } template void ParameterServer::ServerHandler::HandleInitEmbeddings(const ::ps::KVPairs &req_data) { std::shared_ptr>>> shapes = std::make_shared>>>(); std::shared_ptr> input_shape = std::make_shared>(); std::shared_ptr> indices_shape = std::make_shared>(); std::shared_ptr> output_shape = std::make_shared>(); shapes->push_back(input_shape); shapes->push_back(indices_shape); shapes->push_back(output_shape); const Key &key = req_data.keys[0]; const Lengths &lens = req_data.lens; size_t index = 0; for (int i = 0; i < lens[0]; i++) { input_shape->push_back(static_cast(req_data.vals[index++])); } for (int j = 0; j < lens[1]; j++) { indices_shape->push_back(static_cast(req_data.vals[index++])); } for (int k = 0; k < lens[2]; k++) { output_shape->push_back(static_cast(req_data.vals[index++])); } ps_->InitEmbeddingTable(key, shapes); } template void ParameterServer::ServerHandler::HandleEmbeddingLookup(const ::ps::KVMeta &req_meta, const ::ps::KVPairs &req_data, ::ps::KVPairs *res) { const Key &key = req_data.keys[0]; ps_->DoEmbeddingLookup(key, req_data.vals, res); for (size_t i = 0; i < req_data.vals.size(); i++) { res->keys->push_back(req_data.vals[i]); } } template bool ParameterServer::Init(const FuncGraphPtr &func_graph) { const char *server_num = getenv(kEnvPServerNum); const char *worker_num = getenv(kEnvWorkerNum); if (server_num != nullptr) { pserver_num_ = *server_num - '0'; } if (worker_num != nullptr) { worker_num_ = *worker_num - '0'; } func_graph_ = func_graph; rank_id_ = ::ps::MyRank(); handler_.reset(new ServerHandler(this)); InitOptimInfoBuilders(); ps_->set_request_handle(*handler_); thread_.reset(new std::thread(&ParameterServer::UpdateWeights, this)); return true; } template void ParameterServer::InitOptimInfoBuilders() { std::shared_ptr momentum_info_builder = std::make_shared(); std::shared_ptr sparse_adam_info_builder = std::make_shared(); std::shared_ptr sparse_ftrl_info_builder = std::make_shared(); optim_info_builders_[kApplyMomentum] = momentum_info_builder; optim_info_builders_[kSparseAdam] = sparse_adam_info_builder; optim_info_builders_[kSparseFtrl] = sparse_ftrl_info_builder; } template void ParameterServer::InitWeightKeyToOptims(const Key &key, const int &optim_id) { if (weight_key_to_optims_.count(key) > 0 || Util::optimizer_name(key) == "") { return; } weight_key_to_optims_[key] = Util::optimizer_name(optim_id); } template void ParameterServer::InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths) { InputsShapePtr inputs_shape = std::make_shared(); int val_idx = 0; const Key &key = keys[0]; if (optim_inputs_shape_.count(key) == 0) { optim_inputs_shape_[key] = inputs_shape; } for (size_t i = 0; i < keys.size(); i++) { auto shape = std::make_shared>(); inputs_shape->push_back(shape); int len = lengths[i]; for (int j = 0; j < len; j++) { shape->push_back(values[val_idx++]); } } if (weight_key_to_optims_.count(key) > 0) { const std::string &optim_name = weight_key_to_optims_[key]; if (optimizers_.count(optim_name) == 0 && optim_inputs_shape_.count(key) > 0) { if (optim_name == kSparseAdam) { std::shared_ptr optimizer = std::make_shared(rank_id_, pserver_num_); optimizer->InitKernel(optim_inputs_shape_[key]); optimizers_[optim_name] = optimizer; } else if (optim_name == kApplyMomentum) { std::shared_ptr optimizer = std::make_shared(rank_id_, pserver_num_); optimizer->InitKernel(optim_inputs_shape_[key]); optimizers_[optim_name] = optimizer; } else if (optim_name == kSparseFtrl) { std::shared_ptr optimizer = std::make_shared(rank_id_, pserver_num_); optimizer->InitKernel(optim_inputs_shape_[key]); optimizers_[optim_name] = optimizer; } } } } template void ParameterServer::InitWeight(const Key &key, const WeightPtr &weight) { if (weights_.count(key) == 0) { weights_[key] = weight; } } template void ParameterServer::InitGrad(const Key &key, const GradPtr &grad) { if (grads_.count(key) == 0) { grads_[key] = grad; grads_accum_counter_[key] = 0; } } template void ParameterServer::InitEmbeddingTable( const Key &key, const std::shared_ptr>>> &shapes) { // Init embedding lookup kernel std::shared_ptr lookup = std::make_shared(rank_id_, pserver_num_); lookup->InitKernel(shapes); embedding_lookup_ops_[key] = lookup; // Init embedding weight const std::vector &input_shapes = lookup->input_sizes(); size_t total_dims = 1; for (auto shape : input_shapes) { total_dims *= shape; } WeightPtr embedding = std::make_shared(total_dims, 0.01); weights_[key] = embedding; grads_accum_counter_[key] = 0; } template void ParameterServer::UpdateWeights() { while (true) { std::unique_lock lock(mutex_); apply_grads_cv_.wait(lock, [this] { return this->ReadyForUpdateWeights(); }); for (auto iter = weights_.begin(); iter != weights_.end(); iter++) { Key key = iter->first; WeightPtr weight_ptr = iter->second; std::shared_ptr optimizer = nullptr; if (weight_key_to_optims_.count(key) > 0) { const std::string &optim_name = weight_key_to_optims_[key]; optimizer = optimizers_[optim_name]; } MS_EXCEPTION_IF_NULL(optimizer); std::shared_ptr optim_info = optim_infos_[key]; if (optim_info == nullptr) { continue; } const WeightPtr &weight = weights_[key]; optim_info->UpdateWeight(weight); const std::vector &inputs = optim_info->inputs(); const std::vector &workspaces = optim_info->workspaces(); const std::vector &outputs = optim_info->outputs(); optimizer->Execute(inputs, workspaces, outputs); optim_info->Reset(); } ResetGradAccumCount(); accum_grads_cv_.notify_all(); } } template void ParameterServer::AccumGrad(const Keys &keys, const Values &values, const Lengths &lengths) { std::unique_lock lock(mutex_); accum_grads_cv_.wait(lock, [this] { return this->ReadyForAccumGrads(); }); const Key &key = keys[0]; std::shared_ptr optim_info = optim_infos_[key]; // Create or update the optimizer info if (optim_info == nullptr) { const std::shared_ptr &builder = optim_info_builders_[weight_key_to_optims_[key]]; std::shared_ptr pserver_kernel = optimizers_[weight_key_to_optims_[key]]; if (pserver_kernel == nullptr) { MS_LOG(EXCEPTION) << "no optimizer found for key " << key << " optim name " << weight_key_to_optims_[key]; } MS_EXCEPTION_IF_NULL(pserver_kernel); OptimizerInfo *optim = builder->Build(pserver_kernel, weights_[key], keys, values, lengths, optim_inputs_shape_[key], worker_num_); optim_info.reset(optim); optim_infos_[key] = optim_info; } else { optim_info->Update(values, lengths); } MS_EXCEPTION_IF_NULL(optim_info); optim_info->Accumulate(values, lengths); grads_accum_counter_[key] += 1; if (grads_accum_counter_[key] == worker_num_) { grad_accum_count_++; } if (ReadyForUpdateWeights()) { apply_grads_cv_.notify_one(); } } template WeightPtr ParameterServer::weight(const Key &key) { std::unique_lock lock(mutex_); if (weights_.count(key) == 0) { MS_LOG(ERROR) << "Invalid weight key " << key; return nullptr; } WeightPtr weight_ptr = weights_[key]; WeightPtr copy_weight_ptr = std::make_shared<::ps::SArray>(weight_ptr->size(), 0); copy_weight_ptr->CopyFrom(weight_ptr->data(), weight_ptr->size()); return copy_weight_ptr; } template void ParameterServer::DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, ::ps::KVPairs *res) { std::unique_lock lock(mutex_); if (weights_.count(key) == 0) { MS_LOG(ERROR) << "Invalid embedding table key " << key; return; } if (embedding_lookup_ops_.count(key) == 0) { MS_LOG(ERROR) << "Invalid embedding lookup op key " << key; return; } WeightPtr table_ptr = weights_[key]; std::shared_ptr table_lookup_op = embedding_lookup_ops_[key]; // Update shapes of lookup operator std::shared_ptr>>> shapes = std::make_shared>>>(); std::shared_ptr> indices_shape = std::make_shared>(); indices_shape->emplace_back(lookup_ids.size()); shapes->push_back(indices_shape); table_lookup_op->ReInit(shapes); const std::vector output_shapes = table_lookup_op->output_sizes(); std::vector inputs; AddressPtr embedding_table = std::make_shared(); AddressPtr indices = std::make_shared(); inputs.push_back(embedding_table); inputs.push_back(indices); embedding_table->addr = table_ptr->data(); embedding_table->size = table_ptr->size() * sizeof(T); indices->addr = lookup_ids.data(); indices->size = lookup_ids.size() * sizeof(T); std::vector workspaces; std::vector outputs; AddressPtr output = std::make_shared(); std::shared_ptr addr = std::make_shared(output_shapes[0] / sizeof(T), 0); output->addr = addr->data(); output->size = output_shapes[0]; outputs.push_back(output); table_lookup_op->Execute(inputs, workspaces, outputs); res->vals = *addr; res->lens.push_back(res.vals.size()); } template int ParameterServer::SumOfShapes(const std::vector &shapes) const { int sum = 1; for (auto shape : shapes) { sum *= shape; } return sum; } template size_t ParameterServer::PreComputeCapacity(const Keys &keys, const Lengths &lens) { size_t capacity = 0; for (size_t i = 0; i < keys.size(); i++) { Key key = keys[i]; if (embedding_row_lens_.count(key) > 0) { capacity += embedding_row_lens_[key] * lens[i]; } else { MS_LOG(ERROR) << "Invalid embedding lookup id " << key; } } return capacity; } template inline bool ParameterServer::ReadyForUpdateWeights() { return grads_accum_counter_.size() > 0 && grad_accum_count_ == grads_accum_counter_.size(); } template inline bool ParameterServer::ReadyForAccumGrads() { return grad_accum_count_ < weights_.size(); } template inline void ParameterServer::ResetGradAccumCount() { grad_accum_count_ = 0; for (auto iter = grads_accum_counter_.begin(); iter != grads_accum_counter_.end(); iter++) { grads_accum_counter_[iter->first] = 0; } } template void ParameterServer::Run(const FuncGraphPtr &func_graph) { ::ps::Start(0); if (!::ps::IsServer()) { std::cout << "This is not ther Server" << std::endl; return; } Init(func_graph); thread_->join(); } } // namespace ps } // namespace parallel } // namespace mindspore #endif // MINDSPORE_MINDSPORE_CCSRC_PARALLEL_PS_PARAMETER_SERVER_H_