From: @gaoyong10 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -19,36 +19,66 @@ | |||
| namespace mindspore { | |||
| namespace ps { | |||
| int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | |||
| const size_t graph_running_step, size_t *swap_out_size) { | |||
| const size_t graph_running_step, size_t *swap_out_size, bool *need_wait_graph) { | |||
| MS_EXCEPTION_IF_NULL(swap_out_index); | |||
| MS_EXCEPTION_IF_NULL(swap_out_ids); | |||
| MS_EXCEPTION_IF_NULL(swap_out_size); | |||
| auto hash_index = Hash(id); | |||
| auto need_swap = NeedSwap(); | |||
| size_t loop = 0; | |||
| while (true) { | |||
| if (loop++ == hash_capacity_) { | |||
| return INVALID_INDEX_VALUE; | |||
| } | |||
| if (hash_map_elements_[hash_index].IsEmpty()) { | |||
| bool need_swap = false; | |||
| auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph); | |||
| if (hash_index == INVALID_INDEX_VALUE) { | |||
| return hash_index; | |||
| } | |||
| if (!need_swap) { | |||
| hash_count_++; | |||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||
| hash_map_elements_[hash_index].set_id(id); | |||
| hash_map_elements_[hash_index].set_step(data_step); | |||
| return hash_index; | |||
| } | |||
| swap_out_index[*swap_out_size] = hash_index; | |||
| swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; | |||
| (*swap_out_size)++; | |||
| (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); | |||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||
| hash_map_elements_[hash_index].set_id(id); | |||
| hash_map_elements_[hash_index].set_step(data_step); | |||
| return hash_index; | |||
| } | |||
| int EmbeddingHashMap::FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *need_swap, | |||
| bool *need_wait_graph) { | |||
| MS_EXCEPTION_IF_NULL(need_swap); | |||
| MS_EXCEPTION_IF_NULL(need_wait_graph); | |||
| int hash_index = INVALID_INDEX_VALUE; | |||
| while (!expired_element_full_) { | |||
| if (hash_map_elements_[current_pos_].IsEmpty()) { | |||
| hash_index = current_pos_; | |||
| hash_count_++; | |||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||
| hash_map_elements_[hash_index].set_id(id); | |||
| hash_map_elements_[hash_index].set_step(data_step); | |||
| return hash_index; | |||
| } else if (need_swap && hash_map_elements_[hash_index].IsExpired(graph_running_step)) { | |||
| // Need swap out from the hash table. | |||
| swap_out_index[*swap_out_size] = hash_index; | |||
| swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; | |||
| (*swap_out_size)++; | |||
| (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); | |||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||
| hash_map_elements_[hash_index].set_id(id); | |||
| hash_map_elements_[hash_index].set_step(data_step); | |||
| } else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) { | |||
| hash_index = current_pos_; | |||
| *need_swap = true; | |||
| } else if (hash_map_elements_[current_pos_].IsStep(graph_running_step)) { | |||
| graph_running_index_[graph_running_index_num_++] = current_pos_; | |||
| } | |||
| current_pos_ = (current_pos_ + 1) % hash_capacity_; | |||
| if (hash_index != INVALID_INDEX_VALUE) { | |||
| return hash_index; | |||
| } | |||
| hash_index = (hash_index + 1) % hash_capacity_; | |||
| if (current_pos_ == current_batch_start_pos_) { | |||
| expired_element_full_ = true; | |||
| MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_ | |||
| << ") will be used, index swap will wait until the graph completed."; | |||
| } | |||
| } | |||
| if (graph_running_index_pos_ != graph_running_index_num_) { | |||
| *need_swap = true; | |||
| *need_wait_graph = true; | |||
| return graph_running_index_[graph_running_index_pos_++]; | |||
| } | |||
| return INVALID_INDEX_VALUE; | |||
| } | |||
| void EmbeddingHashMap::DumpHashMap() { | |||
| @@ -66,5 +96,12 @@ void EmbeddingHashMap::DumpHashMap() { | |||
| } | |||
| MS_LOG(INFO) << "Dump hash map info end."; | |||
| } | |||
| void EmbeddingHashMap::Reset() { | |||
| current_batch_start_pos_ = current_pos_; | |||
| graph_running_index_num_ = 0; | |||
| graph_running_index_pos_ = 0; | |||
| expired_element_full_ = false; | |||
| } | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -34,6 +34,7 @@ struct HashMapElement { | |||
| size_t step_{INVALID_STEP_VALUE}; | |||
| bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } | |||
| bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } | |||
| bool IsStep(size_t step) const { return step_ == step; } | |||
| void set_id(int id) { id_ = id; } | |||
| void set_step(size_t step) { step_ = step; } | |||
| }; | |||
| @@ -41,25 +42,39 @@ struct HashMapElement { | |||
| // Hash table is held in device, HashMap is used to manage hash table in host. | |||
| class EmbeddingHashMap { | |||
| public: | |||
| EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { | |||
| EmbeddingHashMap(size_t hash_count, size_t hash_capacity) | |||
| : hash_count_(hash_count), | |||
| hash_capacity_(hash_capacity), | |||
| current_pos_(0), | |||
| current_batch_start_pos_(0), | |||
| graph_running_index_num_(0), | |||
| graph_running_index_pos_(0), | |||
| expired_element_full_(false) { | |||
| hash_map_elements_.resize(hash_capacity); | |||
| graph_running_index_ = std::make_unique<int[]>(hash_capacity); | |||
| } | |||
| virtual ~EmbeddingHashMap() = default; | |||
| int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | |||
| const size_t graph_running_step, size_t *swap_out_size); | |||
| const size_t graph_running_step, size_t *swap_out_size, bool *need_wait_graph); | |||
| size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; } | |||
| void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); } | |||
| const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; } | |||
| size_t hash_capacity() const { return hash_capacity_; } | |||
| void DumpHashMap(); | |||
| void Reset(); | |||
| private: | |||
| int Hash(const int id) { return static_cast<int>((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); } | |||
| bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } | |||
| int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *need_swap, bool *need_wait_graph); | |||
| size_t hash_count_; | |||
| size_t hash_capacity_; | |||
| std::vector<HashMapElement> hash_map_elements_; | |||
| std::unordered_map<int, int> hash_id_to_index_; | |||
| size_t current_pos_; | |||
| size_t current_batch_start_pos_; | |||
| size_t graph_running_index_num_; | |||
| size_t graph_running_index_pos_; | |||
| std::unique_ptr<int[]> graph_running_index_; | |||
| bool expired_element_full_; | |||
| }; | |||
| } // namespace ps | |||
| } // namespace mindspore | |||
| @@ -369,6 +369,10 @@ bool PsCacheManager::ProcessData() { | |||
| // Get hash swap in/out index and ids. | |||
| RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); | |||
| DumpStatisticsInfo(); | |||
| if ((device_need_wait_graph_ || host_need_wait_graph_) && (!WaitGraphRun())) { | |||
| MS_LOG(ERROR) << "Ps cache wait graph finish failed."; | |||
| return false; | |||
| } | |||
| for (const auto &item : hash_tables_) { | |||
| auto key = worker.GetParamKey(item.first); | |||
| auto hash_info = item.second; | |||
| @@ -454,6 +458,20 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id | |||
| return true; | |||
| } | |||
| bool PsCacheManager::ResetEmbeddingHashMap() { | |||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||
| const auto &device_hash_map = embedding_device_cache_->device_hash_map_; | |||
| MS_ERROR_IF_NULL(device_hash_map); | |||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||
| const auto &host_hash_map = embedding_host_cache_->host_hash_map_; | |||
| MS_ERROR_IF_NULL(host_hash_map); | |||
| device_hash_map->Reset(); | |||
| host_hash_map->Reset(); | |||
| device_need_wait_graph_ = false; | |||
| host_need_wait_graph_ = false; | |||
| return true; | |||
| } | |||
| bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | |||
| MS_ERROR_IF_NULL(batch_ids); | |||
| MS_ERROR_IF_NULL(hash_index); | |||
| @@ -463,6 +481,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, | |||
| MS_LOG(EXCEPTION) << "Data in device memset failed."; | |||
| } | |||
| CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get()); | |||
| RETURN_IF_FALSE(ResetEmbeddingHashMap()); | |||
| for (size_t i = 0; i < batch_ids_len; i++) { | |||
| if (in_device[i]) { | |||
| continue; | |||
| @@ -529,7 +548,7 @@ bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, | |||
| auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; | |||
| while (true) { | |||
| index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, | |||
| &(statistics_info_.device_to_host_size_)); | |||
| &(statistics_info_.device_to_host_size_), &device_need_wait_graph_); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| if (!WaitGraphRun()) { | |||
| return false; | |||
| @@ -570,8 +589,9 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||
| MS_ERROR_IF_NULL(server_to_host_index); | |||
| MS_ERROR_IF_NULL(server_to_host_ids); | |||
| while (true) { | |||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | |||
| graph_running_step_, &statistics_info_.host_to_server_size_); | |||
| auto index = | |||
| host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_, | |||
| &statistics_info_.host_to_server_size_, &host_need_wait_graph_); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| RETURN_IF_FALSE(WaitGraphRun()); | |||
| continue; | |||
| @@ -607,8 +627,9 @@ bool PsCacheManager::ParseHostDataDeviceToHost() { | |||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | |||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | |||
| while (true) { | |||
| auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, | |||
| data_step_, graph_running_step_, &statistics_info_.host_to_server_size_); | |||
| auto index = | |||
| host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step_, | |||
| graph_running_step_, &statistics_info_.host_to_server_size_, &host_need_wait_graph_); | |||
| if (index == INVALID_INDEX_VALUE) { | |||
| RETURN_IF_FALSE(WaitGraphRun()); | |||
| continue; | |||
| @@ -173,6 +173,7 @@ class PsCacheManager { | |||
| bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, | |||
| size_t *hash_hit_count); | |||
| bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device); | |||
| bool ResetEmbeddingHashMap(); | |||
| bool initialized_ps_cache_{false}; | |||
| std::string channel_name_; | |||
| std::mutex channel_mutex_; | |||
| @@ -198,6 +199,8 @@ class PsCacheManager { | |||
| std::atomic_bool finish_init_parameter_server_{false}; | |||
| std::atomic_bool running_{false}; | |||
| bool finish_embedding_table_sync_{false}; | |||
| bool device_need_wait_graph_{false}; | |||
| bool host_need_wait_graph_{false}; | |||
| }; | |||
| static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | |||