/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef MINDSPORE_CCSRC_PS_WORKER_H_ #define MINDSPORE_CCSRC_PS_WORKER_H_ #include #include #include #include #include #include #include #include #include #include #include #include "utils/log_adapter.h" #include "ir/tensor.h" #include "ps/util.h" #include "ps/constants.h" #include "utils/shape_utils.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" #include "ps/core/worker_node.h" #include "ps/embedding_table_shard_metadata.h" #include "proto/comm.pb.h" #include "proto/ps.pb.h" #include "ps/ps_context.h" namespace mindspore { namespace ps { class Worker { public: static Worker &GetInstance() { static Worker instance; return instance; } using Callback = std::function; using PartitionEmbeddingMessages = std::vector>; using PartitionKVMessages = std::vector>; using EmbeddingPartitioner = std::function &attrs)>; using KVPartitioner = std::function &attrs)>; void Run(); void Push(const std::vector &keys, std::vector addrs, const ShapeVector &sizes); void Pull(const size_t key, void *dev_addr, const size_t size); size_t SetParamKey(const std::string ¶m_name); size_t GetParamKey(const std::string ¶m_name); void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); bool GetParamInitInServer(const std::string ¶m_name); void SetKeyOptimId(size_t key, const std::string &optimizer_name); void SetOptimInputShapes(size_t key, const ShapeVector &shape); void AddEmbeddingTable(const Key &key, const size_t &row_count); void InitPSEmbeddingTable(const size_t &key, const std::vector &input_shape, const std::vector &indices_shape, const std::vector &output_shape, const ParamInitInfoMessage &info); void InitPSParamAndOptim(const AnfNodePtr &input_node, const tensor::TensorPtr &tensor); void DoPSEmbeddingLookup(const Key &key, const std::vector &lookup_ids, std::vector *lookup_result, int64_t cmd); void UpdateEmbeddingTable(const std::vector &keys, const std::vector &lookup_ids, const std::vector &vals); bool running() { return running_; } void Finalize(); private: Worker() : server_num_(-1), running_(false), key_cnt_(0) {} ~Worker() = default; Worker(const Worker &) = delete; Worker &operator=(const Worker &) = delete; void Initialize(); bool IsKeyInit(const size_t key); void AddKeyToServerId(const Key &key); void AddKeyByHashMod(const Key &key); void InitPSOptimId(const size_t param_key); void InitPSOptimInputShapes(const size_t key); void InitPSParamData(const std::vector &keys, void *const origin_addr, size_t size); bool IsReadyForPush(const Key &key); bool IsReadyForPull(const Key &key); void PrepareSparseGradient(const size_t begin, const size_t end, const std::unordered_set &distinct_ids, const std::vector> &indice_to_grads, const int *all_indice, const size_t segment_size, float *gradient, int *indices); void BuildSparseValue(const std::vector &lengths, const size_t grad_index, const size_t indice_index, const float *original_data, const float *grads, int *indices, std::vector *reduced_data); void PushData(const std::vector &keys, const std::vector &vals, const std::vector &lens = {}, int command = 0, int64_t priority = 0); void PushSparseData(const std::vector &keys, const std::vector &vals, const std::vector &lens, size_t grad_index, size_t indice_index, size_t first_dim_size, size_t outer_dim_size); void PullData(const std::vector &keys, std::vector *const vals, std::vector *lens = nullptr, int cmd = 0, int64_t priority = 0); void LookupIdPartitioner(const EmbeddingTableLookup &send, PartitionEmbeddingMessages *partition, const std::map &attrs); void SparsePartitioner(const KVMessage &send, PartitionKVMessages *partition, const std::map &attrs); void RoundRobinPartitioner(const KVMessage &send, PartitionKVMessages *partition, const std::map &attrs); void WorkerInitEmbeddingPartitioner(const KVMessage &send, std::vector> *partition, const std::map &attrs); void UpdateEmbeddingPartitioner(const KVMessage &send, PartitionKVMessages *partition, const std::map &attrs); void BroadcastPartitioner(const KVMessage &send, PartitionKVMessages *partition, const std::map &attrs); void SendForPush(int cmd, const KVMessage &send, const KVPartitioner &partitioner, const std::map &attrs); void SendForPull(int cmd, const KVMessage &send, const KVPartitioner &partitioner, const std::map &attrs, std::vector *vals, std::vector *lens); int64_t server_num_; bool running_; std::mutex running_mutex_; size_t key_cnt_; std::map param_to_key_; std::map init_keys_; std::map key_to_optimId_; std::map> key_to_optim_shapes_; std::map param_to_init_in_server_; core::WorkerNode worker_node_; EmbeddingPartitioner lookup_partitioner_; KVPartitioner sparse_partitioner_; KVPartitioner round_robin_partitioner_; KVPartitioner worker_init_embedding_partitioner_; KVPartitioner update_embedding_partitioner_; KVPartitioner broadcast_partitioner_; std::unordered_map key_to_server_id_; std::unordered_map embedding_row_cnt_; std::unordered_map>> embedding_table_ranges_; }; } // namespace ps } // namespace mindspore #endif // MINDSPORE_CCSRC_PS_WORKER_H_