|
|
|
@@ -331,12 +331,7 @@ void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { |
|
|
|
|
|
|
|
void PsCacheManager::Finalize() { |
|
|
|
if (running_) { |
|
|
|
if (!SyncHostEmbeddingTable()) { |
|
|
|
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed."; |
|
|
|
} |
|
|
|
if (!SyncDeviceEmbeddingTable()) { |
|
|
|
MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed."; |
|
|
|
} |
|
|
|
SyncEmbeddingTable(); |
|
|
|
} |
|
|
|
running_ = false; |
|
|
|
PsDataPrefetch::GetInstance().NotifyFinalize(); |
|
|
|
@@ -846,6 +841,19 @@ bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::SyncEmbeddingTable() { |
|
|
|
if (finish_embedding_table_sync_) { |
|
|
|
return; |
|
|
|
} |
|
|
|
if (!SyncHostEmbeddingTable()) { |
|
|
|
MS_LOG(ERROR) << "SyncHostEmbeddingTable failed."; |
|
|
|
} |
|
|
|
if (!SyncDeviceEmbeddingTable()) { |
|
|
|
MS_LOG(ERROR) << "SyncDeviceEmbeddingTable failed."; |
|
|
|
} |
|
|
|
finish_embedding_table_sync_ = true; |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::SyncHostEmbeddingTable() { |
|
|
|
MS_ERROR_IF_NULL(embedding_host_cache_); |
|
|
|
const auto &hash_id_to_index = embedding_host_cache_->host_hash_map_->hash_id_to_index(); |
|
|
|
|