|
|
|
@@ -394,11 +394,79 @@ bool PsCacheManager::ProcessData() { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::CheckIDInDeviceTask(const int *batch_ids, const size_t batch_ids_len, int *hash_index, |
|
|
|
bool *in_device, size_t *hash_hit_count) { |
|
|
|
MS_ERROR_IF_NULL(batch_ids); |
|
|
|
MS_ERROR_IF_NULL(hash_index); |
|
|
|
MS_ERROR_IF_NULL(in_device); |
|
|
|
MS_ERROR_IF_NULL(hash_hit_count); |
|
|
|
MS_ERROR_IF_NULL(embedding_device_cache_); |
|
|
|
auto &device_hash_map = embedding_device_cache_->device_hash_map_; |
|
|
|
MS_ERROR_IF_NULL(device_hash_map); |
|
|
|
const auto &hash_id_to_index = device_hash_map->hash_id_to_index(); |
|
|
|
|
|
|
|
for (size_t i = 0; i < batch_ids_len; ++i) { |
|
|
|
auto iter = hash_id_to_index.find(batch_ids[i]); |
|
|
|
if (iter != hash_id_to_index.end()) { |
|
|
|
hash_index[i] = iter->second; |
|
|
|
if (device_hash_map->hash_step(iter->second) != data_step_) { |
|
|
|
++(*hash_hit_count); |
|
|
|
device_hash_map->set_hash_step(iter->second, data_step_); |
|
|
|
} |
|
|
|
in_device[i] = true; |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::CheckIDInDevice(const int *batch_ids, const size_t batch_ids_len, int *hash_index, |
|
|
|
bool *in_device) { |
|
|
|
MS_ERROR_IF_NULL(batch_ids); |
|
|
|
MS_ERROR_IF_NULL(hash_index); |
|
|
|
MS_ERROR_IF_NULL(in_device); |
|
|
|
|
|
|
|
size_t thread_num = batch_ids_len / kMinIdsPerThread + 1; |
|
|
|
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; |
|
|
|
std::thread threads[kMaxThreadNum]; |
|
|
|
size_t hash_hit_count[kMaxThreadNum] = {0}; |
|
|
|
size_t i = 0; |
|
|
|
size_t task_offset = 0; |
|
|
|
|
|
|
|
for (; i < thread_num; ++i) { |
|
|
|
if (task_offset >= batch_ids_len) { |
|
|
|
break; |
|
|
|
} |
|
|
|
size_t task_proc_lens = batch_ids_len / thread_num + (i < (batch_ids_len % thread_num) ? 1 : 0); |
|
|
|
threads[i] = std::thread(&PsCacheManager::CheckIDInDeviceTask, this, batch_ids + task_offset, task_proc_lens, |
|
|
|
hash_index + task_offset, in_device + task_offset, hash_hit_count + i); |
|
|
|
task_offset += task_proc_lens; |
|
|
|
} |
|
|
|
if (task_offset != batch_ids_len) { |
|
|
|
MS_LOG(WARNING) << "Ps cache check id in device inadequate, total:" << batch_ids_len << " checked:" << task_offset; |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t j = 0; j < i; j++) { |
|
|
|
threads[j].join(); |
|
|
|
} |
|
|
|
for (size_t j = 0; j < i; j++) { |
|
|
|
statistics_info_.hash_hit_count_ += hash_hit_count[j]; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { |
|
|
|
MS_ERROR_IF_NULL(batch_ids); |
|
|
|
MS_ERROR_IF_NULL(hash_index); |
|
|
|
statistics_info_.batch_id_count_ = batch_ids_len; |
|
|
|
std::unique_ptr<bool[]> in_device(new bool[batch_ids_len]); |
|
|
|
if (memset_s(in_device.get(), batch_ids_len * sizeof(bool), 0, batch_ids_len * sizeof(bool))) { |
|
|
|
MS_LOG(EXCEPTION) << "Data in device memset failed."; |
|
|
|
} |
|
|
|
CheckIDInDevice(batch_ids, batch_ids_len, hash_index, in_device.get()); |
|
|
|
for (size_t i = 0; i < batch_ids_len; i++) { |
|
|
|
if (in_device[i]) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
bool need_swap_host_to_device = true; |
|
|
|
bool need_swap_device_to_host = true; |
|
|
|
auto id = batch_ids[i]; |
|
|
|
@@ -585,10 +653,10 @@ bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l |
|
|
|
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; |
|
|
|
std::thread threads[kMaxThreadNum]; |
|
|
|
size_t task_proc_lens = (indices_lens + thread_num - 1) / thread_num; |
|
|
|
size_t i; |
|
|
|
size_t i = 0; |
|
|
|
size_t task_offset = 0; |
|
|
|
MS_LOG(DEBUG) << "Indices lens: " << indices_lens << ", one task proc lens:" << task_proc_lens; |
|
|
|
for (i = 0; i < thread_num; i++) { |
|
|
|
for (; i < thread_num; i++) { |
|
|
|
if (task_offset >= indices_lens) { |
|
|
|
break; |
|
|
|
} |
|
|
|
@@ -613,7 +681,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in |
|
|
|
thread_num = thread_num > kMaxThreadNum ? kMaxThreadNum : thread_num; |
|
|
|
std::thread threads[kMaxThreadNum]; |
|
|
|
size_t task_proc_lens = (insert_indices_size + thread_num - 1) / thread_num; |
|
|
|
size_t i; |
|
|
|
size_t i = 0; |
|
|
|
size_t task_offset = 0; |
|
|
|
|
|
|
|
auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, |
|
|
|
@@ -632,7 +700,7 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
for (i = 0; i < thread_num; i++) { |
|
|
|
for (; i < thread_num; i++) { |
|
|
|
if (task_offset >= insert_indices_size) { |
|
|
|
break; |
|
|
|
} |
|
|
|
|