|
|
|
@@ -73,6 +73,8 @@ void PsCacheManager::InsertWeightInitInfo(const std::string &param_name, size_t |
|
|
|
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Insert embedding table init info:" << param_name << ", global seed:" << global_seed |
|
|
|
<< ", op seed:" << op_seed; |
|
|
|
hash_table_info.param_init_info_.param_type_ = kWeight; |
|
|
|
hash_table_info.param_init_info_.global_seed_ = global_seed; |
|
|
|
hash_table_info.param_init_info_.op_seed_ = op_seed; |
|
|
|
@@ -91,6 +93,7 @@ void PsCacheManager::InsertAccumuInitInfo(const std::string &param_name, float i |
|
|
|
if (hash_table_info.param_init_info_.param_type_ != kUnKnown) { |
|
|
|
return; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Insert accumulation init info:" << param_name << ", init value:" << init_val; |
|
|
|
hash_table_info.param_init_info_.param_type_ = kAccumulation; |
|
|
|
hash_table_info.param_init_info_.init_val_ = init_val; |
|
|
|
if (CheckFinishInsertInitInfo()) { |
|
|
|
@@ -107,6 +110,7 @@ bool PsCacheManager::CheckFinishInsertInitInfo() const { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Finish inserting embedding table init info."; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -141,6 +145,7 @@ void PsCacheManager::Initialize() { |
|
|
|
AddEmbeddingTable(); |
|
|
|
AllocMemForHashTable(); |
|
|
|
SetLocalIdRank(); |
|
|
|
DumpHashTables(); |
|
|
|
initialized_ps_cache_ = true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -155,6 +160,7 @@ void PsCacheManager::AddEmbeddingTable() const { |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::InitParameterServer() { |
|
|
|
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_; |
|
|
|
std::unique_lock<std::mutex> locker(data_mutex_); |
|
|
|
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); |
|
|
|
|
|
|
|
@@ -181,6 +187,7 @@ void PsCacheManager::InitParameterServer() { |
|
|
|
|
|
|
|
finish_init_parameter_server_ = true; |
|
|
|
data_prase_.notify_one(); |
|
|
|
MS_LOG(INFO) << "Embedding table init end."; |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::AllocMemForHashTable() { |
|
|
|
@@ -237,10 +244,14 @@ void PsCacheManager::set_channel_name(const std::string channel_name) { |
|
|
|
|
|
|
|
void PsCacheManager::IncreaseStep() { |
|
|
|
if (data_step_ >= UINT64_MAX) { |
|
|
|
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t."; |
|
|
|
MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") << will exceed the maximum value of uint64_t."; |
|
|
|
} |
|
|
|
data_step_++; |
|
|
|
set_current_graph_step(); |
|
|
|
if (graph_running_step_ > data_step_) { |
|
|
|
MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") << exceed the data step (" |
|
|
|
<< data_step_ << ")."; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { |
|
|
|
@@ -248,8 +259,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { |
|
|
|
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") << will exceed the maximum value of uint64_t."; |
|
|
|
} |
|
|
|
if (graph_step_ == 0) { |
|
|
|
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_; |
|
|
|
std::unique_lock<std::mutex> locker(data_mutex_); |
|
|
|
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; }); |
|
|
|
MS_LOG(INFO) << "Graph running waiting embedding table init end."; |
|
|
|
} |
|
|
|
graph_step_++; |
|
|
|
set_channel_name(channel_name); |
|
|
|
@@ -755,29 +768,35 @@ void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray<float> &swap_out_da |
|
|
|
worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::DumpHashTables() const { |
|
|
|
void PsCacheManager::DumpHashTables(bool dump_device_tables) const { |
|
|
|
MS_EXCEPTION_IF_NULL(embedding_device_cache_); |
|
|
|
MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); |
|
|
|
for (const auto &item : hash_tables_) { |
|
|
|
const auto ¶m_name = item.first; |
|
|
|
size_t cache_vocab_size = item.second.cache_vocab_size; |
|
|
|
size_t host_cache_vocab_size = item.second.host_cache_vocab_size; |
|
|
|
size_t embedding_size = item.second.embedding_size; |
|
|
|
size_t vocab_size = item.second.vocab_size; |
|
|
|
MS_LOG(INFO) << "Dump hash tables: " << param_name << " || " << cache_vocab_size << " || " << embedding_size |
|
|
|
<< " || " << vocab_size << " || " << reinterpret_cast<void *>(item.second.device_address.addr) |
|
|
|
<< " || " << reinterpret_cast<void *>(item.second.host_address.get()); |
|
|
|
float *output = new float[item.second.device_address.size / 4]; |
|
|
|
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr, |
|
|
|
item.second.device_address.size); |
|
|
|
embedding_device_cache_->cache_->SynchronizeStream(); |
|
|
|
for (size_t i = 0; i < cache_vocab_size; i++) { |
|
|
|
for (size_t j = 0; j < embedding_size; j++) { |
|
|
|
std::cout << output[i * embedding_size + j] << " "; |
|
|
|
MS_LOG(INFO) << "Hash table info:" |
|
|
|
<< " embedding table name:" << param_name << ", vocab size:" << vocab_size |
|
|
|
<< ", embedding size:" << embedding_size << ", device cache size:" << cache_vocab_size |
|
|
|
<< ", host cache size:" << host_cache_vocab_size |
|
|
|
<< ", device cache address:" << reinterpret_cast<void *>(item.second.device_address.addr) |
|
|
|
<< ", host cache address:" << reinterpret_cast<void *>(item.second.host_address.get()); |
|
|
|
if (dump_device_tables) { |
|
|
|
float *output = new float[item.second.device_address.size / 4]; |
|
|
|
embedding_device_cache_->cache_->CopyDeviceMemToHost(output, item.second.device_address.addr, |
|
|
|
item.second.device_address.size); |
|
|
|
embedding_device_cache_->cache_->SynchronizeStream(); |
|
|
|
for (size_t i = 0; i < cache_vocab_size; i++) { |
|
|
|
for (size_t j = 0; j < embedding_size; j++) { |
|
|
|
std::cout << output[i * embedding_size + j] << " "; |
|
|
|
} |
|
|
|
std::cout << std::endl; |
|
|
|
} |
|
|
|
std::cout << std::endl; |
|
|
|
delete[] output; |
|
|
|
} |
|
|
|
std::cout << std::endl; |
|
|
|
delete[] output; |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace ps |
|
|
|
|