/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ #define MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "ir/func_graph.h" #include "backend/session/session_basic.h" #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/session_factory.h" #include "ps/optimizer_info.h" #include "ps/optimizer_info_builder.h" #include "ps/ps_context.h" #include "runtime/device/cpu/kernel_select_cpu.h" #include "utils/ms_context.h" #include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" #include "backend/kernel_compiler/cpu/ps/pserver_kernel.h" #include "backend/kernel_compiler/cpu/ps/sparse_apply_adam_ps_kernel.h" #include "backend/kernel_compiler/cpu/ps/sparse_apply_lazy_adam_ps_kernel.h" #include "backend/kernel_compiler/cpu/ps/sparse_apply_ftrl_ps_kernel.h" #include "backend/kernel_compiler/cpu/ps/apply_momentum_ps_kernel.h" #include "backend/kernel_compiler/cpu/ps/embedding_look_up_ps_kernel.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" #include "ps/random_normal/random_normal.h" #include "ps/constants.h" #include "ps/util.h" #include "ps/embedding_table_shard_metadata.h" #include "utils/log_adapter.h" #include "proto/comm.pb.h" #include "proto/ps.pb.h" #include "ps/core/server_node.h" 
#include "ps/core/node.h"

namespace mindspore {
namespace ps {
// Parameter-server side state and request dispatch, exposed as a process-wide singleton.
//
// Only GetInstance() and Run() are public; every other member is private and defined out of line,
// so the comments below describe declarations only.
//
// NOTE(review): several declarations in this header appear to have lost their template argument
// lists (e.g. `std::shared_ptr conn`, `std::unordered_map handlers_`, `std::shared_ptr>>> &shapes`).
// They are reproduced exactly as found because the element types cannot be determined from this
// file alone -- restore them from version control before building.
class ParameterServer {
 public:
  // Returns the single instance; it is a function-local static, constructed on the first call.
  static ParameterServer &GetInstance() {
    static ParameterServer instance;
    return instance;
  }

  // Public entry point taking the function graph; defined out of line.
  void Run(const FuncGraphPtr &func_graph);

 private:
  // Private so that GetInstance() is the only public way to obtain an instance.
  // Counters start at zero, pointer-like members start null, and running_ starts true.
  ParameterServer()
      : pserver_num_(0),
        worker_num_(0),
        grad_accum_count_(0),
        handler_(nullptr),
        func_graph_(nullptr),
        sess_(nullptr),
        running_(true),
        thread_(nullptr),
        server_node_(nullptr) {}
  ~ParameterServer() = default;

  // Non-copyable.
  ParameterServer(const ParameterServer &) = delete;
  ParameterServer &operator=(const ParameterServer &) = delete;

  // Callable object that receives a request (connection, meta, payload, size) and exposes one
  // Handle* method per request kind. It keeps a raw back-pointer to the ParameterServer it was
  // constructed with; the defaulted destructor does not delete that pointer.
  class ServerHandler {
   public:
    explicit ServerHandler(ParameterServer *ps) : ps_(ps) {}
    ~ServerHandler() = default;

    void Init();

    // NOTE(review): the template arguments of both std::shared_ptr parameters are missing here.
    void operator()(std::shared_ptr conn, std::shared_ptr meta, DataPtr data, size_t size);

    // Request handlers. They all share the RequestHandler signature declared below.
    void HandlePushReq(DataPtr data, size_t size, VectorPtr res);
    void HandlePullReq(DataPtr data, size_t size, VectorPtr res);
    void HandleInitWeights(DataPtr data, size_t size, VectorPtr res);
    void HandleInitWeightToOptimId(DataPtr data, size_t size, VectorPtr res);
    void HandleInitInputsShape(DataPtr data, size_t size, VectorPtr res);
    void HandleInitEmbeddings(DataPtr data, size_t size, VectorPtr res);
    void HandleCheckReadyForPush(DataPtr data, size_t size, VectorPtr res);
    void HandleCheckReadyForPull(DataPtr data, size_t size, VectorPtr res);
    void HandleEmbeddingLookup(DataPtr data, size_t size, VectorPtr res);
    void HandleUpdateEmbeddings(DataPtr data, size_t size, VectorPtr res);
    void HandleFinalize(DataPtr data, size_t size, VectorPtr res);

   private:
    ParameterServer *ps_;  // Not owned.

    // Pointer-to-member type matching the Handle* methods above.
    using RequestHandler = void (ServerHandler::*)(DataPtr data, size_t size, VectorPtr res);

    // NOTE(review): key/value types of the maps below are missing; presumably handlers_ maps a
    // command id to a RequestHandler -- confirm against the implementation file.
    std::unordered_map handlers_;
    std::unordered_map commands_;
    std::unordered_map init_weights_;
    std::unordered_map init_weight_to_optim_;
    std::unordered_map init_optim_info_;
  };

  // Initialisation helpers (defined out of line).
  bool Init(const FuncGraphPtr &func_graph);
  void InitOptimInfoBuilders();
  void InitWeightKeyToOptims(const Key &key, const int64_t &optim_id);
  void InitOptimInputsShape(const Keys &keys, const Values &values, const Lengths &lengths);
  void InitWeight(const Key &key, const WeightPtr &weight);
  void InitGrad(const Key &key, const GradPtr &grad);
  // NOTE(review): the type of `shapes` has lost its template arguments (only `>>>` remains).
  void InitEmbeddingTable(const Key &key, const std::shared_ptr>>> &shapes, const ParamInitInfo &param_init_info);

  bool HasWeight(const Key &key);
  void Finalize();
  void UpdateWeights();
  void AccumGrad(const Keys &key, const Values &values, const Lengths &lengths);
  WeightPtr weight(const Key &key);
  void DoEmbeddingLookup(Key key, const LookupIds &lookup_ids, KVMessage *res);
  void UpdateEmbeddings(const Key &key, const LookupIds &lookup_ids, const Values &vals);

  // Readiness predicates and counter reset (defined out of line).
  bool ReadyForUpdateWeights();
  bool ReadyForPush(const Key &key);
  bool ReadyForPull(const Key &key);
  void ResetGradAccumCount();

  const CNodePtr GetCNode(const std::string &name) const;
  std::mutex &mutex();
  void GetEmbeddingTableParamPtr();
  void SyncEmbeddingTables();

  size_t pserver_num_;
  size_t worker_num_;
  size_t grad_accum_count_;
  std::unique_ptr handler_;  // NOTE(review): pointee type missing; presumably ServerHandler -- confirm.
  FuncGraphPtr func_graph_;
  std::shared_ptr sess_;  // NOTE(review): pointee type missing.
  bool running_;

  // Per-key tables. NOTE(review): key/value types were stripped from every map below.
  std::unordered_map> optimizers_;
  std::unordered_map optim_inputs_shape_;
  std::unordered_map original_optim_inputs_shape_;
  std::unordered_map> optim_infos_;
  std::unordered_map> optim_info_builders_;
  std::unordered_map weight_key_to_optims_;
  std::unordered_map weight_key_to_optim_op_;
  std::unordered_map weights_;
  std::unordered_map is_embedding_;
  std::unordered_map grads_;
  std::unordered_map grads_accum_counter_;
  std::unordered_map> embedding_lookup_ops_;
  std::unordered_map tokens_;

  std::mutex mutex_;
  std::condition_variable apply_grads_cv_;
  std::unique_ptr thread_;  // NOTE(review): pointee type missing.
  std::shared_ptr server_node_;  // NOTE(review): pointee type missing.
  std::map embedding_tables_;

  friend class ServerHandler;
};
}  // namespace ps
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_PS_PARAMETER_SERVER_H_