From 01e9ca59228730cadb77128cfbdd3cb847ed04ed Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Sat, 19 Dec 2020 10:55:46 +0800 Subject: [PATCH] support ps cache data process thread exit --- .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 149 ++++++++++-------- .../ccsrc/ps/ps_cache/ps_cache_manager.h | 7 +- 2 files changed, 91 insertions(+), 65 deletions(-) diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index 7b7b30ffe8..3b360bfb53 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -168,7 +168,7 @@ void PsCacheManager::AddEmbeddingTable() const { } void PsCacheManager::InitParameterServer() { - MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_; + MS_LOG(INFO) << "PS embedding cache table init begin:" << finish_insert_init_info_; std::unique_lock locker(data_mutex_); insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; }); if (!running_) { @@ -197,7 +197,20 @@ void PsCacheManager::InitParameterServer() { finish_init_parameter_server_ = true; data_prase_.notify_one(); - MS_LOG(INFO) << "Embedding table init end."; + MS_LOG(INFO) << "PS embedding cache table init end."; +} + +void PsCacheManager::InitDataChannel() { + MS_LOG(INFO) << "PS embedding cache data channel init begin."; + auto channel = channel_name(); + if (channel.empty()) { + std::unique_lock locker(data_mutex_); + data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; }); + if (!running_) { + return; + } + } + MS_LOG(INFO) << "PS embedding cache data channel init end."; } void PsCacheManager::AllocMemForHashTable() { @@ -270,8 +283,8 @@ bool PsCacheManager::IncreaseStep() { } void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { - if (terminated_) { - MS_LOG(EXCEPTION) << "ps cache data process thread is terminated."; + if (!running_) { + MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running."; } if (graph_step_ >= UINT64_MAX) { MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; @@ -279,7 +292,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { if (graph_step_ == 0) { MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_; std::unique_lock locker(data_mutex_); - data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; }); + data_prase_.wait(locker, [this] { return ((finish_init_parameter_server_ == true) || (running_ == false)); }); + if (!running_) { + MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running."; + } MS_LOG(INFO) << "Graph running waiting embedding table init end."; } graph_step_++; @@ -300,25 +316,21 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { } void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { - embedding_device_cache_->cache_->InitDevice(device_id, context); + MS_LOG(INFO) << "PS embedding cache process data task begin."; running_ = true; - bool ret = true; + embedding_device_cache_->cache_->InitDevice(device_id, context); InitParameterServer(); - while (ret) { - if (!running_) { - break; + InitDataChannel(); + while (running_) { + if (!ProcessData()) { + running_ = false; } - ret = ProcessData(); - } - if (!ret) { - terminated_ = true; } + MS_LOG(INFO) << "PS embedding cache process data task end."; } void PsCacheManager::Finalize() { - if (running_) { - running_ = false; - } + running_ = false; PsDataPrefetch::GetInstance().NotifyFinalize(); insert_init_info_.notify_all(); data_prase_.notify_all(); @@ -331,14 +343,6 @@ bool PsCacheManager::ProcessData() { struct timeval start_time, end_time; const uint64_t kUSecondInSecond = 1000000; (void)gettimeofday(&start_time, nullptr); - auto channel = channel_name(); - if (channel.empty()) { - std::unique_lock locker(data_mutex_); - data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; }); - if (!running_) { - return false; - } - } auto data = PsDataPrefetch::GetInstance().data(channel_name_); if (data == nullptr) { MS_LOG(INFO) << "No data process, channel name:" << channel_name_; @@ -361,6 +365,7 @@ bool PsCacheManager::ProcessData() { } // Get hash swap in/out index and ids. RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); + DumpStatisticsInfo(); for (const auto &item : hash_tables_) { auto key = worker.GetParamKey(item.first); auto hash_info = item.second; @@ -389,6 +394,7 @@ bool PsCacheManager::ProcessData() { bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { MS_ERROR_IF_NULL(batch_ids); MS_ERROR_IF_NULL(hash_index); + statistics_info_.batch_id_count_ = batch_ids_len; for (size_t i = 0; i < batch_ids_len; i++) { bool need_swap_host_to_device = true; bool need_swap_device_to_host = true; @@ -397,10 +403,8 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, hash_index[i] = -1; continue; } - auto index = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); - if (index == INVALID_INDEX_VALUE) { - return false; - } + int index = INVALID_INDEX_VALUE; + RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index)); hash_index[i] = index; if (need_swap_host_to_device) { RETURN_IF_FALSE(ParseHostDataHostToDevice(id)); @@ -409,12 +413,6 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, RETURN_IF_FALSE(ParseHostDataDeviceToHost(id)); } } - // Each 1000 step prints ps cache hit rate. - if (data_step_ % 1000 == 0) { - statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_; - auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; - MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; - } return true; } @@ -430,14 +428,16 @@ bool PsCacheManager::WaitGraphRun() { return true; } -int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { - int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); - int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); - int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); - int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); - +bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, + int *hash_index) { + MS_ERROR_IF_NULL(need_swap_device_to_host); + MS_ERROR_IF_NULL(need_swap_host_to_device); + MS_ERROR_IF_NULL(hash_index); + MS_ERROR_IF_NULL(embedding_device_cache_); auto device_hash_map = embedding_device_cache_->device_hash_map_; - int index = 0; + MS_ERROR_IF_NULL(device_hash_map); + + int index = INVALID_INDEX_VALUE; auto iter = device_hash_map->id_iter(id); if (device_hash_map->IsIdExist(iter)) { *need_swap_device_to_host = false; @@ -448,13 +448,19 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b device_hash_map->set_hash_step(index, data_step_); } } else { + int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); + int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); + int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); + int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); + MS_ERROR_IF_NULL(host_to_device_index); + MS_ERROR_IF_NULL(host_to_device_ids); auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; while (true) { index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, &(statistics_info_.device_to_host_size_)); if (index == INVALID_INDEX_VALUE) { if (!WaitGraphRun()) { - return INVALID_INDEX_VALUE; + return false; } continue; } @@ -465,23 +471,17 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b break; } } - return index; + *hash_index = index; + return true; } bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { - int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); - int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); - int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); - int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); + MS_ERROR_IF_NULL(embedding_host_cache_); int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); - MS_ERROR_IF_NULL(host_to_server_index); - MS_ERROR_IF_NULL(host_to_server_ids); - MS_ERROR_IF_NULL(server_to_host_index); - MS_ERROR_IF_NULL(server_to_host_ids); MS_ERROR_IF_NULL(host_to_device_index); - auto host_hash_map = embedding_host_cache_->host_hash_map_; MS_ERROR_IF_NULL(host_hash_map); + auto iter = host_hash_map->id_iter(id); if (host_hash_map->IsIdExist(iter)) { auto index = iter->second; @@ -490,6 +490,12 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { } host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; } else { + int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); + int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); + int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); + int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); + MS_ERROR_IF_NULL(server_to_host_index); + MS_ERROR_IF_NULL(server_to_host_ids); while (true) { auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_, &statistics_info_.host_to_server_size_); @@ -507,13 +513,10 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { } bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { + MS_ERROR_IF_NULL(embedding_device_cache_); int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); - int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); - int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); MS_ERROR_IF_NULL(device_to_host_ids); - MS_ERROR_IF_NULL(host_to_server_index); - MS_ERROR_IF_NULL(host_to_server_ids); MS_ERROR_IF_NULL(device_to_host_index); auto host_hash_map = embedding_host_cache_->host_hash_map_; @@ -527,6 +530,8 @@ bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { } device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; } else { + int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); + int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); while (true) { auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_, &statistics_info_.host_to_server_size_); @@ -552,13 +557,13 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); if (ret != EOK) { MS_LOG(ERROR) << "LookUpTable task memcpy failed."; - terminated_ = true; + running_ = false; } } else { auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); if (ret != EOK) { MS_LOG(ERROR) << "LookUpTable task memset failed."; - terminated_ = true; + running_ = false; } } output_addr += outer_dim_size; @@ -592,7 +597,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l for (size_t j = 0; j < i; j++) { threads[j].join(); } - return !terminated_; + return running_; } bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, @@ -615,7 +620,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); if (ret != EOK) { MS_LOG(ERROR) << "Insert hash table task memcpy failed."; - terminated_ = true; + running_ = false; } } } @@ -637,7 +642,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in for (size_t j = 0; j < i; j++) { threads[j].join(); } - return !terminated_; + return running_; } bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { @@ -862,5 +867,25 @@ void PsCacheManager::DumpHashTables(bool dump_device_tables) const { } } } + +void PsCacheManager::DumpStatisticsInfo(size_t each_print_step) { + // Default each 1000 step prints ps cache hit rate. + if (data_step_ % each_print_step == 0) { + statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_; + auto repeat_rate = SizeToFloat(statistics_info_.batch_id_count_ - statistics_info_.batch_id_unique_count_) / + statistics_info_.batch_id_count_; + auto device_hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; + auto host_hit_rate = SizeToFloat(statistics_info_.batch_id_unique_count_ - statistics_info_.server_to_host_size_) / + statistics_info_.batch_id_unique_count_; + MS_LOG(INFO) << "PS embedding cache data statistics info(total id num:" << statistics_info_.batch_id_count_ + << ", unique id num:" << statistics_info_.batch_id_unique_count_ + << ", host swap to device num:" << statistics_info_.host_to_device_size_ + << ", device swap to host num:" << statistics_info_.device_to_host_size_ + << ", host swap to server num:" << statistics_info_.host_to_server_size_ + << ", server swap to host num:" << statistics_info_.server_to_host_size_ + << ", data repeat rate:" << repeat_rate * 100 << "%, device cache hit rate:" << device_hit_rate * 100 + << "%, host cache hit rate:" << host_hit_rate * 100 << ")."; + } +} } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h index 99eb6cdb00..bbf1db6519 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h @@ -94,6 +94,7 @@ struct EmbeddingHostCache { }; struct PsCacheStatisticsInfo { + size_t batch_id_count_{0}; size_t batch_id_unique_count_{0}; size_t device_to_host_size_{0}; size_t host_to_device_size_{0}; @@ -126,7 +127,6 @@ class PsCacheManager { bool initialized_ps_cache() const { return initialized_ps_cache_; } void DoProcessData(uint32_t device_id, void *context); void IncreaseGraphStep(const std::string &channel_name); - bool terminated() const { return terminated_; } void Finalize(); void DumpHashTables(bool dump_device_tables = false) const; @@ -140,13 +140,14 @@ class PsCacheManager { std::string channel_name(); void set_channel_name(const std::string channel_name); void InitParameterServer(); + void InitDataChannel(); void AllocMemForHashTable(); void SetLocalIdRank(); void ProcessDataTask(uint32_t device_id, void *context); bool ProcessData(); bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); bool WaitGraphRun(); - int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device); + bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index); bool ParseHostDataHostToDevice(size_t id); bool ParseHostDataDeviceToHost(size_t id); bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray *swap_out_data, const HashTableInfo &hash_info); @@ -164,6 +165,7 @@ class PsCacheManager { const int *indices_addr, float *output_addr); bool CheckFinishInsertInitInfo() const; void AddEmbeddingTable() const; + void DumpStatisticsInfo(size_t each_print_step = 1000); bool initialized_ps_cache_{false}; std::string channel_name_; @@ -189,7 +191,6 @@ class PsCacheManager { std::atomic_bool finish_insert_init_info_{false}; std::atomic_bool finish_init_parameter_server_{false}; std::atomic_bool running_{false}; - std::atomic_bool terminated_{false}; }; static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance();