|
|
|
@@ -168,7 +168,7 @@ void PsCacheManager::AddEmbeddingTable() const { |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::InitParameterServer() { |
|
|
|
MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_; |
|
|
|
MS_LOG(INFO) << "PS embedding cache table init begin:" << finish_insert_init_info_; |
|
|
|
std::unique_lock<std::mutex> locker(data_mutex_); |
|
|
|
insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; }); |
|
|
|
if (!running_) { |
|
|
|
@@ -197,7 +197,20 @@ void PsCacheManager::InitParameterServer() { |
|
|
|
|
|
|
|
finish_init_parameter_server_ = true; |
|
|
|
data_prase_.notify_one(); |
|
|
|
MS_LOG(INFO) << "Embedding table init end."; |
|
|
|
MS_LOG(INFO) << "PS embedding cache table init end."; |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::InitDataChannel() { |
|
|
|
MS_LOG(INFO) << "PS embedding cache data channel init begin."; |
|
|
|
auto channel = channel_name(); |
|
|
|
if (channel.empty()) { |
|
|
|
std::unique_lock<std::mutex> locker(data_mutex_); |
|
|
|
data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; }); |
|
|
|
if (!running_) { |
|
|
|
return; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "PS embedding cache data channel init end."; |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::AllocMemForHashTable() { |
|
|
|
@@ -270,8 +283,8 @@ bool PsCacheManager::IncreaseStep() { |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { |
|
|
|
if (terminated_) { |
|
|
|
MS_LOG(EXCEPTION) << "ps cache data process thread is terminated."; |
|
|
|
if (!running_) { |
|
|
|
MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running."; |
|
|
|
} |
|
|
|
if (graph_step_ >= UINT64_MAX) { |
|
|
|
MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; |
|
|
|
@@ -279,7 +292,10 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { |
|
|
|
if (graph_step_ == 0) { |
|
|
|
MS_LOG(INFO) << "Graph running waiting embedding table init begin:" << finish_init_parameter_server_; |
|
|
|
std::unique_lock<std::mutex> locker(data_mutex_); |
|
|
|
data_prase_.wait(locker, [this] { return finish_init_parameter_server_ == true; }); |
|
|
|
data_prase_.wait(locker, [this] { return ((finish_init_parameter_server_ == true) || (running_ == false)); }); |
|
|
|
if (!running_) { |
|
|
|
MS_LOG(EXCEPTION) << "PS embedding cache data processing thread isn't running."; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "Graph running waiting embedding table init end."; |
|
|
|
} |
|
|
|
graph_step_++; |
|
|
|
@@ -300,25 +316,21 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { |
|
|
|
embedding_device_cache_->cache_->InitDevice(device_id, context); |
|
|
|
MS_LOG(INFO) << "PS embedding cache process data task begin."; |
|
|
|
running_ = true; |
|
|
|
bool ret = true; |
|
|
|
embedding_device_cache_->cache_->InitDevice(device_id, context); |
|
|
|
InitParameterServer(); |
|
|
|
while (ret) { |
|
|
|
if (!running_) { |
|
|
|
break; |
|
|
|
InitDataChannel(); |
|
|
|
while (running_) { |
|
|
|
if (!ProcessData()) { |
|
|
|
running_ = false; |
|
|
|
} |
|
|
|
ret = ProcessData(); |
|
|
|
} |
|
|
|
if (!ret) { |
|
|
|
terminated_ = true; |
|
|
|
} |
|
|
|
MS_LOG(INFO) << "PS embedding cache process data task end."; |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::Finalize() { |
|
|
|
if (running_) { |
|
|
|
running_ = false; |
|
|
|
} |
|
|
|
running_ = false; |
|
|
|
PsDataPrefetch::GetInstance().NotifyFinalize(); |
|
|
|
insert_init_info_.notify_all(); |
|
|
|
data_prase_.notify_all(); |
|
|
|
@@ -331,14 +343,6 @@ bool PsCacheManager::ProcessData() { |
|
|
|
struct timeval start_time, end_time; |
|
|
|
const uint64_t kUSecondInSecond = 1000000; |
|
|
|
(void)gettimeofday(&start_time, nullptr); |
|
|
|
auto channel = channel_name(); |
|
|
|
if (channel.empty()) { |
|
|
|
std::unique_lock<std::mutex> locker(data_mutex_); |
|
|
|
data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; }); |
|
|
|
if (!running_) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
auto data = PsDataPrefetch::GetInstance().data(channel_name_); |
|
|
|
if (data == nullptr) { |
|
|
|
MS_LOG(INFO) << "No data process, channel name:" << channel_name_; |
|
|
|
@@ -361,6 +365,7 @@ bool PsCacheManager::ProcessData() { |
|
|
|
} |
|
|
|
// Get hash swap in/out index and ids. |
|
|
|
RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); |
|
|
|
DumpStatisticsInfo(); |
|
|
|
for (const auto &item : hash_tables_) { |
|
|
|
auto key = worker.GetParamKey(item.first); |
|
|
|
auto hash_info = item.second; |
|
|
|
@@ -389,6 +394,7 @@ bool PsCacheManager::ProcessData() { |
|
|
|
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { |
|
|
|
MS_ERROR_IF_NULL(batch_ids); |
|
|
|
MS_ERROR_IF_NULL(hash_index); |
|
|
|
statistics_info_.batch_id_count_ = batch_ids_len; |
|
|
|
for (size_t i = 0; i < batch_ids_len; i++) { |
|
|
|
bool need_swap_host_to_device = true; |
|
|
|
bool need_swap_device_to_host = true; |
|
|
|
@@ -397,10 +403,8 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, |
|
|
|
hash_index[i] = -1; |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto index = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); |
|
|
|
if (index == INVALID_INDEX_VALUE) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
int index = INVALID_INDEX_VALUE; |
|
|
|
RETURN_IF_FALSE(ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device, &index)); |
|
|
|
hash_index[i] = index; |
|
|
|
if (need_swap_host_to_device) { |
|
|
|
RETURN_IF_FALSE(ParseHostDataHostToDevice(id)); |
|
|
|
@@ -409,12 +413,6 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, |
|
|
|
RETURN_IF_FALSE(ParseHostDataDeviceToHost(id)); |
|
|
|
} |
|
|
|
} |
|
|
|
// Each 1000 step prints ps cache hit rate. |
|
|
|
if (data_step_ % 1000 == 0) { |
|
|
|
statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_; |
|
|
|
auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; |
|
|
|
MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -430,14 +428,16 @@ bool PsCacheManager::WaitGraphRun() { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { |
|
|
|
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); |
|
|
|
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); |
|
|
|
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); |
|
|
|
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); |
|
|
|
|
|
|
|
bool PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, |
|
|
|
int *hash_index) { |
|
|
|
MS_ERROR_IF_NULL(need_swap_device_to_host); |
|
|
|
MS_ERROR_IF_NULL(need_swap_host_to_device); |
|
|
|
MS_ERROR_IF_NULL(hash_index); |
|
|
|
MS_ERROR_IF_NULL(embedding_device_cache_); |
|
|
|
auto device_hash_map = embedding_device_cache_->device_hash_map_; |
|
|
|
int index = 0; |
|
|
|
MS_ERROR_IF_NULL(device_hash_map); |
|
|
|
|
|
|
|
int index = INVALID_INDEX_VALUE; |
|
|
|
auto iter = device_hash_map->id_iter(id); |
|
|
|
if (device_hash_map->IsIdExist(iter)) { |
|
|
|
*need_swap_device_to_host = false; |
|
|
|
@@ -448,13 +448,19 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b |
|
|
|
device_hash_map->set_hash_step(index, data_step_); |
|
|
|
} |
|
|
|
} else { |
|
|
|
int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); |
|
|
|
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); |
|
|
|
int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); |
|
|
|
int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); |
|
|
|
MS_ERROR_IF_NULL(host_to_device_index); |
|
|
|
MS_ERROR_IF_NULL(host_to_device_ids); |
|
|
|
auto tmp_device_to_host_size = statistics_info_.device_to_host_size_; |
|
|
|
while (true) { |
|
|
|
index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, |
|
|
|
&(statistics_info_.device_to_host_size_)); |
|
|
|
if (index == INVALID_INDEX_VALUE) { |
|
|
|
if (!WaitGraphRun()) { |
|
|
|
return INVALID_INDEX_VALUE; |
|
|
|
return false; |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
@@ -465,23 +471,17 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
return index; |
|
|
|
*hash_index = index; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { |
|
|
|
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); |
|
|
|
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); |
|
|
|
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); |
|
|
|
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); |
|
|
|
MS_ERROR_IF_NULL(embedding_host_cache_); |
|
|
|
int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); |
|
|
|
MS_ERROR_IF_NULL(host_to_server_index); |
|
|
|
MS_ERROR_IF_NULL(host_to_server_ids); |
|
|
|
MS_ERROR_IF_NULL(server_to_host_index); |
|
|
|
MS_ERROR_IF_NULL(server_to_host_ids); |
|
|
|
MS_ERROR_IF_NULL(host_to_device_index); |
|
|
|
|
|
|
|
auto host_hash_map = embedding_host_cache_->host_hash_map_; |
|
|
|
MS_ERROR_IF_NULL(host_hash_map); |
|
|
|
|
|
|
|
auto iter = host_hash_map->id_iter(id); |
|
|
|
if (host_hash_map->IsIdExist(iter)) { |
|
|
|
auto index = iter->second; |
|
|
|
@@ -490,6 +490,12 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { |
|
|
|
} |
|
|
|
host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; |
|
|
|
} else { |
|
|
|
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); |
|
|
|
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); |
|
|
|
int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); |
|
|
|
int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); |
|
|
|
MS_ERROR_IF_NULL(server_to_host_index); |
|
|
|
MS_ERROR_IF_NULL(server_to_host_ids); |
|
|
|
while (true) { |
|
|
|
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, |
|
|
|
graph_running_step_, &statistics_info_.host_to_server_size_); |
|
|
|
@@ -507,13 +513,10 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { |
|
|
|
MS_ERROR_IF_NULL(embedding_device_cache_); |
|
|
|
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); |
|
|
|
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); |
|
|
|
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); |
|
|
|
int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); |
|
|
|
MS_ERROR_IF_NULL(device_to_host_ids); |
|
|
|
MS_ERROR_IF_NULL(host_to_server_index); |
|
|
|
MS_ERROR_IF_NULL(host_to_server_ids); |
|
|
|
MS_ERROR_IF_NULL(device_to_host_index); |
|
|
|
|
|
|
|
auto host_hash_map = embedding_host_cache_->host_hash_map_; |
|
|
|
@@ -527,6 +530,8 @@ bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { |
|
|
|
} |
|
|
|
device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; |
|
|
|
} else { |
|
|
|
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); |
|
|
|
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); |
|
|
|
while (true) { |
|
|
|
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, |
|
|
|
graph_running_step_, &statistics_info_.host_to_server_size_); |
|
|
|
@@ -552,13 +557,13 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, |
|
|
|
auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "LookUpTable task memcpy failed."; |
|
|
|
terminated_ = true; |
|
|
|
running_ = false; |
|
|
|
} |
|
|
|
} else { |
|
|
|
auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "LookUpTable task memset failed."; |
|
|
|
terminated_ = true; |
|
|
|
running_ = false; |
|
|
|
} |
|
|
|
} |
|
|
|
output_addr += outer_dim_size; |
|
|
|
@@ -592,7 +597,7 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l |
|
|
|
for (size_t j = 0; j < i; j++) { |
|
|
|
threads[j].join(); |
|
|
|
} |
|
|
|
return !terminated_; |
|
|
|
return running_; |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, |
|
|
|
@@ -615,7 +620,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in |
|
|
|
auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "Insert hash table task memcpy failed."; |
|
|
|
terminated_ = true; |
|
|
|
running_ = false; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -637,7 +642,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in |
|
|
|
for (size_t j = 0; j < i; j++) { |
|
|
|
threads[j].join(); |
|
|
|
} |
|
|
|
return !terminated_; |
|
|
|
return running_; |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { |
|
|
|
@@ -862,5 +867,25 @@ void PsCacheManager::DumpHashTables(bool dump_device_tables) const { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void PsCacheManager::DumpStatisticsInfo(size_t each_print_step) { |
|
|
|
// Default each 1000 step prints ps cache hit rate. |
|
|
|
if (data_step_ % each_print_step == 0) { |
|
|
|
statistics_info_.batch_id_unique_count_ = statistics_info_.hash_hit_count_ + statistics_info_.host_to_device_size_; |
|
|
|
auto repeat_rate = SizeToFloat(statistics_info_.batch_id_count_ - statistics_info_.batch_id_unique_count_) / |
|
|
|
statistics_info_.batch_id_count_; |
|
|
|
auto device_hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; |
|
|
|
auto host_hit_rate = SizeToFloat(statistics_info_.batch_id_unique_count_ - statistics_info_.server_to_host_size_) / |
|
|
|
statistics_info_.batch_id_unique_count_; |
|
|
|
MS_LOG(INFO) << "PS embedding cache data statistics info(total id num:" << statistics_info_.batch_id_count_ |
|
|
|
<< ", unique id num:" << statistics_info_.batch_id_unique_count_ |
|
|
|
<< ", host swap to device num:" << statistics_info_.host_to_device_size_ |
|
|
|
<< ", device swap to host num:" << statistics_info_.device_to_host_size_ |
|
|
|
<< ", host swap to server num:" << statistics_info_.host_to_server_size_ |
|
|
|
<< ", server swap to host num:" << statistics_info_.server_to_host_size_ |
|
|
|
<< ", data repeat rate:" << repeat_rate * 100 << "%, device cache hit rate:" << device_hit_rate * 100 |
|
|
|
<< "%, host cache hit rate:" << host_hit_rate * 100 << ")."; |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace ps |
|
|
|
} // namespace mindspore |