From: @gaoyong10 Reviewed-by: Signed-off-by:tags/v1.2.0-rc1
| @@ -19,36 +19,66 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace ps { | namespace ps { | ||||
| int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | int EmbeddingHashMap::ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | ||||
| const size_t graph_running_step, size_t *swap_out_size) { | |||||
| const size_t graph_running_step, size_t *swap_out_size, bool *need_wait_graph) { | |||||
| MS_EXCEPTION_IF_NULL(swap_out_index); | MS_EXCEPTION_IF_NULL(swap_out_index); | ||||
| MS_EXCEPTION_IF_NULL(swap_out_ids); | MS_EXCEPTION_IF_NULL(swap_out_ids); | ||||
| MS_EXCEPTION_IF_NULL(swap_out_size); | MS_EXCEPTION_IF_NULL(swap_out_size); | ||||
| auto hash_index = Hash(id); | |||||
| auto need_swap = NeedSwap(); | |||||
| size_t loop = 0; | |||||
| while (true) { | |||||
| if (loop++ == hash_capacity_) { | |||||
| return INVALID_INDEX_VALUE; | |||||
| } | |||||
| if (hash_map_elements_[hash_index].IsEmpty()) { | |||||
| bool need_swap = false; | |||||
| auto hash_index = FindInsertionPos(data_step, graph_running_step, &need_swap, need_wait_graph); | |||||
| if (hash_index == INVALID_INDEX_VALUE) { | |||||
| return hash_index; | |||||
| } | |||||
| if (!need_swap) { | |||||
| hash_count_++; | |||||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||||
| hash_map_elements_[hash_index].set_id(id); | |||||
| hash_map_elements_[hash_index].set_step(data_step); | |||||
| return hash_index; | |||||
| } | |||||
| swap_out_index[*swap_out_size] = hash_index; | |||||
| swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; | |||||
| (*swap_out_size)++; | |||||
| (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); | |||||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||||
| hash_map_elements_[hash_index].set_id(id); | |||||
| hash_map_elements_[hash_index].set_step(data_step); | |||||
| return hash_index; | |||||
| } | |||||
| int EmbeddingHashMap::FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *need_swap, | |||||
| bool *need_wait_graph) { | |||||
| MS_EXCEPTION_IF_NULL(need_swap); | |||||
| MS_EXCEPTION_IF_NULL(need_wait_graph); | |||||
| int hash_index = INVALID_INDEX_VALUE; | |||||
| while (!expired_element_full_) { | |||||
| if (hash_map_elements_[current_pos_].IsEmpty()) { | |||||
| hash_index = current_pos_; | |||||
| hash_count_++; | hash_count_++; | ||||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||||
| hash_map_elements_[hash_index].set_id(id); | |||||
| hash_map_elements_[hash_index].set_step(data_step); | |||||
| return hash_index; | |||||
| } else if (need_swap && hash_map_elements_[hash_index].IsExpired(graph_running_step)) { | |||||
| // Need swap out from the hash table. | |||||
| swap_out_index[*swap_out_size] = hash_index; | |||||
| swap_out_ids[*swap_out_size] = hash_map_elements_[hash_index].id_; | |||||
| (*swap_out_size)++; | |||||
| (void)hash_id_to_index_.erase(hash_map_elements_[hash_index].id_); | |||||
| (void)hash_id_to_index_.emplace(id, hash_index); | |||||
| hash_map_elements_[hash_index].set_id(id); | |||||
| hash_map_elements_[hash_index].set_step(data_step); | |||||
| } else if (hash_map_elements_[current_pos_].IsExpired(graph_running_step)) { | |||||
| hash_index = current_pos_; | |||||
| *need_swap = true; | |||||
| } else if (hash_map_elements_[current_pos_].IsStep(graph_running_step)) { | |||||
| graph_running_index_[graph_running_index_num_++] = current_pos_; | |||||
| } | |||||
| current_pos_ = (current_pos_ + 1) % hash_capacity_; | |||||
| if (hash_index != INVALID_INDEX_VALUE) { | |||||
| return hash_index; | return hash_index; | ||||
| } | } | ||||
| hash_index = (hash_index + 1) % hash_capacity_; | |||||
| if (current_pos_ == current_batch_start_pos_) { | |||||
| expired_element_full_ = true; | |||||
| MS_LOG(INFO) << "Running step:" << graph_running_step << "(num:" << graph_running_index_num_ | |||||
| << ") will be used, index swap will wait until the graph completed."; | |||||
| } | |||||
| } | } | ||||
| if (graph_running_index_pos_ != graph_running_index_num_) { | |||||
| *need_swap = true; | |||||
| *need_wait_graph = true; | |||||
| return graph_running_index_[graph_running_index_pos_++]; | |||||
| } | |||||
| return INVALID_INDEX_VALUE; | |||||
| } | } | ||||
| void EmbeddingHashMap::DumpHashMap() { | void EmbeddingHashMap::DumpHashMap() { | ||||
| @@ -66,5 +96,12 @@ void EmbeddingHashMap::DumpHashMap() { | |||||
| } | } | ||||
| MS_LOG(INFO) << "Dump hash map info end."; | MS_LOG(INFO) << "Dump hash map info end."; | ||||
| } | } | ||||
| void EmbeddingHashMap::Reset() { | |||||
| current_batch_start_pos_ = current_pos_; | |||||
| graph_running_index_num_ = 0; | |||||
| graph_running_index_pos_ = 0; | |||||
| expired_element_full_ = false; | |||||
| } | |||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -34,6 +34,7 @@ struct HashMapElement { | |||||
| size_t step_{INVALID_STEP_VALUE}; | size_t step_{INVALID_STEP_VALUE}; | ||||
| bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } | bool IsEmpty() const { return step_ == INVALID_STEP_VALUE; } | ||||
| bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } | bool IsExpired(size_t graph_running_step) const { return graph_running_step > step_; } | ||||
| bool IsStep(size_t step) const { return step_ == step; } | |||||
| void set_id(int id) { id_ = id; } | void set_id(int id) { id_ = id; } | ||||
| void set_step(size_t step) { step_ = step; } | void set_step(size_t step) { step_ = step; } | ||||
| }; | }; | ||||
| @@ -41,25 +42,39 @@ struct HashMapElement { | |||||
| // Hash table is held in device, HashMap is used to manage hash table in host. | // Hash table is held in device, HashMap is used to manage hash table in host. | ||||
| class EmbeddingHashMap { | class EmbeddingHashMap { | ||||
| public: | public: | ||||
| EmbeddingHashMap(size_t hash_count, size_t hash_capacity) : hash_count_(hash_count), hash_capacity_(hash_capacity) { | |||||
| EmbeddingHashMap(size_t hash_count, size_t hash_capacity) | |||||
| : hash_count_(hash_count), | |||||
| hash_capacity_(hash_capacity), | |||||
| current_pos_(0), | |||||
| current_batch_start_pos_(0), | |||||
| graph_running_index_num_(0), | |||||
| graph_running_index_pos_(0), | |||||
| expired_element_full_(false) { | |||||
| hash_map_elements_.resize(hash_capacity); | hash_map_elements_.resize(hash_capacity); | ||||
| graph_running_index_ = std::make_unique<int[]>(hash_capacity); | |||||
| } | } | ||||
| virtual ~EmbeddingHashMap() = default; | virtual ~EmbeddingHashMap() = default; | ||||
| int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | int ParseData(const int id, int *swap_out_index, int *swap_out_ids, const size_t data_step, | ||||
| const size_t graph_running_step, size_t *swap_out_size); | |||||
| const size_t graph_running_step, size_t *swap_out_size, bool *need_wait_graph); | |||||
| size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; } | size_t hash_step(const int hash_index) const { return hash_map_elements_[hash_index].step_; } | ||||
| void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); } | void set_hash_step(const int hash_index, const size_t step) { hash_map_elements_[hash_index].set_step(step); } | ||||
| const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; } | const std::unordered_map<int, int> &hash_id_to_index() const { return hash_id_to_index_; } | ||||
| size_t hash_capacity() const { return hash_capacity_; } | size_t hash_capacity() const { return hash_capacity_; } | ||||
| void DumpHashMap(); | void DumpHashMap(); | ||||
| void Reset(); | |||||
| private: | private: | ||||
| int Hash(const int id) { return static_cast<int>((0.6180339 * id - std::floor(0.6180339 * id)) * hash_capacity_); } | |||||
| bool NeedSwap() const { return hash_count_ > FloatToSize(hash_capacity_ * 0.9); } | |||||
| int FindInsertionPos(const size_t data_step, const size_t graph_running_step, bool *need_swap, bool *need_wait_graph); | |||||
| size_t hash_count_; | size_t hash_count_; | ||||
| size_t hash_capacity_; | size_t hash_capacity_; | ||||
| std::vector<HashMapElement> hash_map_elements_; | std::vector<HashMapElement> hash_map_elements_; | ||||
| std::unordered_map<int, int> hash_id_to_index_; | std::unordered_map<int, int> hash_id_to_index_; | ||||
| size_t current_pos_; | |||||
| size_t current_batch_start_pos_; | |||||
| size_t graph_running_index_num_; | |||||
| size_t graph_running_index_pos_; | |||||
| std::unique_ptr<int[]> graph_running_index_; | |||||
| bool expired_element_full_; | |||||
| }; | }; | ||||
| } // namespace ps | } // namespace ps | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -369,6 +369,10 @@ bool PsCacheManager::ProcessData() { | |||||
| // Get hash swap in/out index and ids. | // Get hash swap in/out index and ids. | ||||
| RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); | RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); | ||||
| DumpStatisticsInfo(); | DumpStatisticsInfo(); | ||||
| if ((device_need_wait_graph_ || host_need_wait_graph_) && (!WaitGraphRun())) { | |||||
| MS_LOG(ERROR) << "Ps cache wait graph finish failed."; | |||||
| return false; | |||||
| } | |||||
| for (const auto &item : hash_tables_) { | for (const auto &item : hash_tables_) { | ||||
| auto key = worker.GetParamKey(item.first); | auto key = worker.GetParamKey(item.first); | ||||
| auto hash_info = item.second; | auto hash_info = item.second; | ||||
| @@ -454,6 +458,20 @@ bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_id | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool PsCacheManager::ResetEmbeddingHashMap() { | |||||
| MS_ERROR_IF_NULL(embedding_device_cache_); | |||||
| const auto &device_hash_map = embedding_device_cache_->device_hash_map_; | |||||
| MS_ERROR_IF_NULL(device_hash_map); | |||||
| MS_ERROR_IF_NULL(embedding_host_cache_); | |||||
| const auto &host_hash_map = embedding_host_cache_->host_hash_map_; | |||||
| MS_ERROR_IF_NULL(host_hash_map); | |||||
| device_hash_map->Reset(); | |||||
| host_hash_map->Reset(); | |||||
| device_need_wait_graph_ = false; | |||||
| host_need_wait_graph_ = false; | |||||
| return true; | |||||
| } | |||||
| bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { | ||||
| MS_ERROR_IF_NULL(batch_ids); | MS_ERROR_IF_NULL(batch_ids); | ||||
| MS_ERROR_IF_NULL(hash_index); | MS_ERROR_IF_NULL(hash_index); | ||||
| @@ -463,6 +481,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, | |||||
| MS_LOG(EXCEPTION) << "Data in device memset failed."; | MS_LOG(EXCEPTION) << "Data in device memset failed."; | ||||
| } | } | ||||
| CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get()); | CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get()); | ||||
| RETURN_IF_FALSE(ResetEmbeddingHashMap()); | |||||
| for (size_t i = 0; i < batch_ids_len; i++) { | for (size_t i = 0; i < batch_ids_len; i++) { | ||||
| if (in_device[i]) { | if (in_device[i]) { | ||||
| continue; | continue; | ||||
| @@ -529,7 +548,7 @@ bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, | |||||
| auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; | auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; | ||||
| while (true) { | while (true) { | ||||
| index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, | index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, | ||||
| &(statistics_info_.device_to_host_size_)); | |||||
| &(statistics_info_.device_to_host_size_), &device_need_wait_graph_); | |||||
| if (index == INVALID_INDEX_VALUE) { | if (index == INVALID_INDEX_VALUE) { | ||||
| if (!WaitGraphRun()) { | if (!WaitGraphRun()) { | ||||
| return false; | return false; | ||||
| @@ -570,8 +589,9 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { | |||||
| MS_ERROR_IF_NULL(server_to_host_index); | MS_ERROR_IF_NULL(server_to_host_index); | ||||
| MS_ERROR_IF_NULL(server_to_host_ids); | MS_ERROR_IF_NULL(server_to_host_ids); | ||||
| while (true) { | while (true) { | ||||
| auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, | |||||
| graph_running_step_, &statistics_info_.host_to_server_size_); | |||||
| auto index = | |||||
| host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_, | |||||
| &statistics_info_.host_to_server_size_, &host_need_wait_graph_); | |||||
| if (index == INVALID_INDEX_VALUE) { | if (index == INVALID_INDEX_VALUE) { | ||||
| RETURN_IF_FALSE(WaitGraphRun()); | RETURN_IF_FALSE(WaitGraphRun()); | ||||
| continue; | continue; | ||||
| @@ -607,8 +627,9 @@ bool PsCacheManager::ParseHostDataDeviceToHost() { | |||||
| int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); | ||||
| int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); | ||||
| while (true) { | while (true) { | ||||
| auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, | |||||
| data_step_, graph_running_step_, &statistics_info_.host_to_server_size_); | |||||
| auto index = | |||||
| host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, data_step_, | |||||
| graph_running_step_, &statistics_info_.host_to_server_size_, &host_need_wait_graph_); | |||||
| if (index == INVALID_INDEX_VALUE) { | if (index == INVALID_INDEX_VALUE) { | ||||
| RETURN_IF_FALSE(WaitGraphRun()); | RETURN_IF_FALSE(WaitGraphRun()); | ||||
| continue; | continue; | ||||
| @@ -173,6 +173,7 @@ class PsCacheManager { | |||||
| bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, | bool CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device, | ||||
| size_t *hash_hit_count); | size_t *hash_hit_count); | ||||
| bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device); | bool CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, bool *in_device); | ||||
| bool ResetEmbeddingHashMap(); | |||||
| bool initialized_ps_cache_{false}; | bool initialized_ps_cache_{false}; | ||||
| std::string channel_name_; | std::string channel_name_; | ||||
| std::mutex channel_mutex_; | std::mutex channel_mutex_; | ||||
| @@ -198,6 +199,8 @@ class PsCacheManager { | |||||
| std::atomic_bool finish_init_parameter_server_{false}; | std::atomic_bool finish_init_parameter_server_{false}; | ||||
| std::atomic_bool running_{false}; | std::atomic_bool running_{false}; | ||||
| bool finish_embedding_table_sync_{false}; | bool finish_embedding_table_sync_{false}; | ||||
| bool device_need_wait_graph_{false}; | |||||
| bool host_need_wait_graph_{false}; | |||||
| }; | }; | ||||
| static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | ||||