From: @limingqi107 Reviewed-by: @cristoval,@zhoufeng54 Signed-off-by: @cristovaltags/v1.1.0
| @@ -331,12 +331,7 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { | |||||
| void PsCacheManager::Finalize() { | void PsCacheManager::Finalize() { | ||||
| if (running_) { | if (running_) { | ||||
| if (!SyncHostEmbeddingTable()) { | |||||
| MS_LOG(ERROR) << "SyncHostEmbeddingTable failed."; | |||||
| } | |||||
| if (!SyncDeviceEmbeddingTable()) { | |||||
| MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed."; | |||||
| } | |||||
| SyncEmbeddingTable(); | |||||
| } | } | ||||
| running_ = false; | running_ = false; | ||||
| PsDataPrefetch::GetInstance().NotifyFinalize(); | PsDataPrefetch::GetInstance().NotifyFinalize(); | ||||
| @@ -846,6 +841,19 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da | |||||
| return true; | return true; | ||||
| } | } | ||||
| void PsCacheManager::SyncEmbeddingTable() { | |||||
| if (finish_embedding_table_sync_) { | |||||
| return; | |||||
| } | |||||
| if (!SyncHostEmbeddingTable()) { | |||||
| MS_LOG(ERROR) << "SyncHostEmbeddingTable failed."; | |||||
| } | |||||
| if (!SyncDeviceEmbeddingTable()) { | |||||
| MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed."; | |||||
| } | |||||
| finish_embedding_table_sync_ = true; | |||||
| } | |||||
| bool PsCacheManager::SyncHostEmbeddingTable() { | bool PsCacheManager::SyncHostEmbeddingTable() { | ||||
| MS_ERROR_IF_NULL(embedding_host_cache_); | MS_ERROR_IF_NULL(embedding_host_cache_); | ||||
| const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index(); | const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index(); | ||||
| @@ -127,6 +127,7 @@ class PsCacheManager { | |||||
| bool initialized_ps_cache() const { return initialized_ps_cache_; } | bool initialized_ps_cache() const { return initialized_ps_cache_; } | ||||
| void DoProcessData(uint32_t device_id, void *context); | void DoProcessData(uint32_t device_id, void *context); | ||||
| void IncreaseGraphStep(const std::string &channel_name); | void IncreaseGraphStep(const std::string &channel_name); | ||||
| void SyncEmbeddingTable(); | |||||
| void Finalize(); | void Finalize(); | ||||
| void DumpHashTables(bool dump_device_tables = false) const; | void DumpHashTables(bool dump_device_tables = false) const; | ||||
| @@ -193,6 +194,7 @@ class PsCacheManager { | |||||
| std::atomic_bool finish_insert_init_info_{false}; | std::atomic_bool finish_insert_init_info_{false}; | ||||
| std::atomic_bool finish_init_parameter_server_{false}; | std::atomic_bool finish_init_parameter_server_{false}; | ||||
| std::atomic_bool running_{false}; | std::atomic_bool running_{false}; | ||||
| bool finish_embedding_table_sync_{false}; | |||||
| }; | }; | ||||
| static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); | ||||
| @@ -16,10 +16,16 @@ | |||||
| #include "runtime/device/kernel_runtime_manager.h" | #include "runtime/device/kernel_runtime_manager.h" | ||||
| #include "utils/log_adapter.h" | #include "utils/log_adapter.h" | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| #include "ps/ps_cache/ps_cache_manager.h" | |||||
| #endif | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace device { | namespace device { | ||||
| void KernelRuntimeManager::ClearRuntimeResource() { | void KernelRuntimeManager::ClearRuntimeResource() { | ||||
| #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) | |||||
| ps::ps_cache_instance.SyncEmbeddingTable(); | |||||
| #endif | |||||
| std::lock_guard<std::mutex> guard(lock_); | std::lock_guard<std::mutex> guard(lock_); | ||||
| for (auto &iter : runtime_map_) { | for (auto &iter : runtime_map_) { | ||||
| MS_LOG(INFO) << "Release device " << iter.first; | MS_LOG(INFO) << "Release device " << iter.first; | ||||