From 5bdd2eb18c3a25402d7c683031a15970f8ba3790 Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Wed, 20 Jan 2021 15:48:00 +0800 Subject: [PATCH] ps cache parse linear --- .../ccsrc/ps/ps_cache/embedding_hash_map.cc | 83 ++++++++++++++----- .../ccsrc/ps/ps_cache/embedding_hash_map.h | 23 ++++- .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 31 +++++-- .../ccsrc/ps/ps_cache/ps_cache_manager.h | 3 + 4 files changed, 108 insertions(+), 32 deletions(-) diff --git a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc index 4f5c298b90..75775326e0 100755 --- a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc +++ b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.cc @@ -19,36 +19,66 @@ namespace mindspore { namespace ps { int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, - const size_t graph_running_step, size_t *swap_out_size) { + const size_t graph_running_step, size_t *swap_out_size, bool *need_wait_graph) { MS_EXCEPTION_IF_NULL(swap_out_index); MS_EXCEPTION_IF_NULL(swap_out_ids); MS_EXCEPTION_IF_NULL(swap_out_size); - auto hash_index = Hash(id); - auto need_swap = NeedSwap(); - size_t loop = 0; - while (true) { - if (loop++ == hash_capacity_) { - return INVALID_INDEX_VALUE; - } - if (hash_map_elements_[hash_index].IsEmpty()) { + bool need_swap = false; + auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph); + if (hash_index == INVALID_INDEX_VALUE) { + return hash_index; + } + + if (!need_swap) { + hash_count_++; + (void)hash_id_to_index_.emplace(id, hash_index); + hash_map_elements_[hash_index].set_id(id); + hash_map_elements_[hash_index].set_step(data_step); + return hash_index; + } + + swap_out_index[*swap_out_size] = hash_index; + swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; + (*swap_out_size)++; + (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); + (void)hash_id_to_index_.emplace(id, hash_index); + hash_map_elements_[hash_index].set_id(id); + hash_map_elements_[hash_index].set_step(data_step); + return hash_index; +} + +int EmbeddingHashMap::FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *need_swap, + bool *need_wait_graph) { + MS_EXCEPTION_IF_NULL(need_swap); + MS_EXCEPTION_IF_NULL(need_wait_graph); + int hash_index = INVALID_INDEX_VALUE; + while (!expired_element_full_) { + if (hash_map_elements_[current_pos_].IsEmpty()) { + hash_index = current_pos_; hash_count_++; - (void)hash_id_to_index_.emplace(id, hash_index); - hash_map_elements_[hash_index].set_id(id); - hash_map_elements_[hash_index].set_step(data_step); - return hash_index; - } else if (need_swap && hash_map_elements_[hash_index].IsExpired(graph_running_step)) { - // Need swap out from the hash table. - swap_out_index[*swap_out_size] = hash_index; - swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; - (*swap_out_size)++; - (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); - (void)hash_id_to_index_.emplace(id, hash_index); - hash_map_elements_[hash_index].set_id(id); - hash_map_elements_[hash_index].set_step(data_step); + } else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) { + hash_index = current_pos_; + *need_swap = true; + } else if (hash_map_elements_[current_pos_].IsStep(graph_running_step)) { + graph_running_index_[graph_running_index_num_++] = current_pos_; + } + current_pos_ = (current_pos_ + 1) % hash_capacity_; + if (hash_index != INVALID_INDEX_VALUE) { return hash_index; } - hash_index = (hash_index + 1) % hash_capacity_; + if (current_pos_ == current_batch_start_pos_) { + expired_element_full_ = true; + MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_ + << ") will be used, index swap will wait until the graph completed."; + } } + + if (graph_running_index_pos_ != graph_running_index_num_) { + *need_swap = true; + *need_wait_graph = true; + return graph_running_index_[graph_running_index_pos_++]; + } + return INVALID_INDEX_VALUE; } void EmbeddingHashMap::DumpHashMap() { @@ -66,5 +96,12 @@ void EmbeddingHashMap::DumpHashMap() { } MS_LOG(INFO) << "Dump hash map info end."; } + +void EmbeddingHashMap::Reset() { + current_batch_start_pos_ = current_pos_; + graph_running_index_num_ = 0; + graph_running_index_pos_ = 0; + expired_element_full_ = false; +} } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h index 47dca4ca85..0fe17a5d73 100644 --- a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h +++ b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h @@ -34,6 +34,7 @@ struct HashMapElement { size_t step_{INVALID_STEP_VALUE}; bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } + bool IsStep(size_t step) const { return step_ == step; } void set_id(int id) { id_ = id; } void set_step(size_t step) { step_ = step; } }; @@ -41,25 +42,39 @@ struct HashMapElement { // Hash table is held in device, HashMap is used to manage hash table in host. class EmbeddingHashMap { public: - EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { + EmbeddingHashMap(size_t hash_count, size_t hash_capacity) + : hash_count_(hash_count), + hash_capacity_(hash_capacity), + current_pos_(0), + current_batch_start_pos_(0), + graph_running_index_num_(0), + graph_running_index_pos_(0), + expired_element_full_(false) { hash_map_elements_.resize(hash_capacity); + graph_running_index_ = std::make_unique(hash_capacity); } virtual ~EmbeddingHashMap() = default; int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, - const size_t graph_running_step, size_t *swap_out_size); + const size_t graph_running_step, size_t *swap_out_size, bool *need_wait_graph); size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; } void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); } const std::unordered_map &hash_id_to_index() const { return hash_id_to_index_; } size_t hash_capacity() const { return hash_capacity_; } void DumpHashMap(); + void Reset(); private: - int Hash(const int id) { return static_cast((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); } - bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } + int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *need_swap, bool *need_wait_graph); size_t hash_count_; size_t hash_capacity_; std::vector hash_map_elements_; std::unordered_map hash_id_to_index_; + size_t current_pos_; + size_t current_batch_start_pos_; + size_t graph_running_index_num_; + size_t graph_running_index_pos_; + std::unique_ptr graph_running_index_; + bool expired_element_full_; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index 415f51c029..7f7c09368c 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -369,6 +369,10 @@ bool PsCacheManager::ProcessData() { // Get hash swap in/out index and ids. RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); DumpStatisticsInfo(); + if ((device_need_wait_graph_ || host_need_wait_graph_) && (!WaitGraphRun())) { + MS_LOG(ERROR) << "Ps cache wait graph finish failed."; + return false; + } for (const auto &item : hash_tables_) { auto key = worker.GetParamKey(item.first); auto hash_info = item.second; @@ -454,6 +458,20 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id return true; } +bool PsCacheManager::ResetEmbeddingHashMap() { + MS_ERROR_IF_NULL(embedding_device_cache_); + const auto &device_hash_map = embedding_device_cache_->device_hash_map_; + MS_ERROR_IF_NULL(device_hash_map); + MS_ERROR_IF_NULL(embedding_host_cache_); + const auto &host_hash_map = embedding_host_cache_->host_hash_map_; + MS_ERROR_IF_NULL(host_hash_map); + device_hash_map->Reset(); + host_hash_map->Reset(); + device_need_wait_graph_ = false; + host_need_wait_graph_ = false; + return true; +} + bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { MS_ERROR_IF_NULL(batch_ids); MS_ERROR_IF_NULL(hash_index); @@ -463,6 +481,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, MS_LOG(EXCEPTION) << "Data in device memset failed."; } CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get()); + RETURN_IF_FALSE(ResetEmbeddingHashMap()); for (size_t i = 0; i < batch_ids_len; i++) { if (in_device[i]) { continue; @@ -529,7 +548,7 @@ bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; while (true) { index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, - &(statistics_info_.device_to_host_size_)); + &(statistics_info_.device_to_host_size_), &device_need_wait_graph_); if (index == INVALID_INDEX_VALUE) { if (!WaitGraphRun()) { return false; @@ -570,8 +589,9 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { MS_ERROR_IF_NULL(server_to_host_index); MS_ERROR_IF_NULL(server_to_host_ids); while (true) { - auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, - graph_running_step_, &statistics_info_.host_to_server_size_); + auto index = + host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_, + &statistics_info_.host_to_server_size_, &host_need_wait_graph_); if (index == INVALID_INDEX_VALUE) { RETURN_IF_FALSE(WaitGraphRun()); continue; @@ -607,8 +627,9 @@ bool PsCacheManager::ParseHostDataDeviceToHost() { int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); while (true) { - auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, - data_step_, graph_running_step_, &statistics_info_.host_to_server_size_); + auto index = + host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step_, + graph_running_step_, &statistics_info_.host_to_server_size_, &host_need_wait_graph_); if (index == INVALID_INDEX_VALUE) { RETURN_IF_FALSE(WaitGraphRun()); continue; diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h index 851f2b6c57..f67bcc8610 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h @@ -173,6 +173,7 @@ class PsCacheManager { bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, size_t *hash_hit_count); bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device); + bool ResetEmbeddingHashMap(); bool initialized_ps_cache_{false}; std::string channel_name_; std::mutex channel_mutex_; @@ -198,6 +199,8 @@ class PsCacheManager { std::atomic_bool finish_init_parameter_server_{false}; std::atomic_bool running_{false}; bool finish_embedding_table_sync_{false}; + bool device_need_wait_graph_{false}; + bool host_need_wait_graph_{false}; }; static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();