|
|
|
@@ -222,9 +222,8 @@ void PsCacheManager::AllocMemForHashTable() { |
|
|
|
device_address.addr = addr; |
|
|
|
|
|
|
|
auto &host_address = item.second.host_address; |
|
|
|
auto host_address_ptr = new float[host_vocab_cache_size_ * embedding_size]; |
|
|
|
MS_EXCEPTION_IF_NULL(host_address_ptr); |
|
|
|
host_address = std::shared_ptr<float[]>(host_address_ptr, std::default_delete<float[]>()); |
|
|
|
host_address = |
|
|
|
std::shared_ptr<float[]>(new float[host_vocab_cache_size_ * embedding_size], std::default_delete<float[]>()); |
|
|
|
MS_EXCEPTION_IF_NULL(host_address); |
|
|
|
|
|
|
|
max_embedding_size = (embedding_size > max_embedding_size) ? embedding_size : max_embedding_size; |
|
|
|
@@ -387,8 +386,9 @@ bool PsCacheManager::ProcessData() { |
|
|
|
RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info)); |
|
|
|
RETURN_IF_FALSE(HashSwapHostToDevice(hash_info)); |
|
|
|
} |
|
|
|
size_t dest_len = data_size; |
|
|
|
// Replace the batch_ids by hash index for getNext-op getting hash index as input. |
|
|
|
if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { |
|
|
|
if (memcpy_s(data, dest_len, hash_index.get(), data_size) != EOK) { |
|
|
|
MS_LOG(ERROR) << "Process data memcpy failed."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
@@ -727,11 +727,13 @@ bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in |
|
|
|
auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, |
|
|
|
const int *insert_indices, const float *insert_data, float *hash_table_addr) { |
|
|
|
auto type_size = sizeof(float); |
|
|
|
size_t lens = outer_dim_size * type_size; |
|
|
|
size_t copy_len = outer_dim_size * type_size; |
|
|
|
size_t dest_len = copy_len; |
|
|
|
for (size_t i = 0; i < insert_indices_size; ++i) { |
|
|
|
int index = insert_indices[i]; |
|
|
|
if (index >= 0 && index < SizeToInt(first_dim_size)) { |
|
|
|
auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); |
|
|
|
auto ret = |
|
|
|
memcpy_s(hash_table_addr + index * outer_dim_size, dest_len, insert_data + i * outer_dim_size, copy_len); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "Insert hash table task memcpy failed."; |
|
|
|
running_ = false; |
|
|
|
@@ -836,8 +838,9 @@ bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_ |
|
|
|
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, |
|
|
|
swap_out_data.data())); |
|
|
|
|
|
|
|
auto copy_len = swap_indices_size * sizeof(int); |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); |
|
|
|
size_t copy_len = swap_indices_size * sizeof(int); |
|
|
|
size_t dest_len = copy_len; |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids, copy_len); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "Lookup id memcpy failed."; |
|
|
|
return false; |
|
|
|
@@ -858,8 +861,9 @@ bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_ |
|
|
|
auto embedding_size = hash_info.embedding_size; |
|
|
|
std::vector<float> lookup_result(swap_indices_size * embedding_size, 0); |
|
|
|
std::vector<int> lookup_ids(swap_indices_size, 0); |
|
|
|
auto copy_len = swap_indices_size * sizeof(int); |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); |
|
|
|
size_t copy_len = swap_indices_size * sizeof(int); |
|
|
|
size_t dest_len = copy_len; |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), dest_len, server_to_host_ids, copy_len); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "Lookup id memcpy failed."; |
|
|
|
return false; |
|
|
|
@@ -912,8 +916,9 @@ bool PsCacheManager::HashSwapDeviceIn(const int *swap_in_ids, const int *swap_in |
|
|
|
// Get id embs by swap_in_ids in host(Pipeline with hash swap-out in device). |
|
|
|
std::vector<float> lookup_result(swap_in_ids_size * embedding_size, 0); |
|
|
|
std::vector<int> lookup_ids(swap_in_ids_size, 0); |
|
|
|
auto copy_len = swap_in_ids_size * sizeof(int); |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); |
|
|
|
size_t copy_len = swap_in_ids_size * sizeof(int); |
|
|
|
size_t dest_len = copy_len; |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_in_ids, copy_len); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "Lookup id memcpy failed."; |
|
|
|
return false; |
|
|
|
@@ -940,8 +945,9 @@ bool PsCacheManager::UpdataEmbeddingTable(const std::vector<float> &swap_out_dat |
|
|
|
return true; |
|
|
|
} |
|
|
|
std::vector<int> lookup_ids(swap_out_ids_size, 0); |
|
|
|
auto copy_len = swap_out_ids_size * sizeof(int); |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); |
|
|
|
size_t copy_len = swap_out_ids_size * sizeof(int); |
|
|
|
size_t dest_len = copy_len; |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), dest_len, swap_out_ids, copy_len); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "Lookup id memcpy failed."; |
|
|
|
return false; |
|
|
|
@@ -1000,8 +1006,9 @@ bool PsCacheManager::SyncHostEmbeddingTable() { |
|
|
|
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, host_hash_table_addr, |
|
|
|
host_to_server_indices_ptr.get(), swap_out_data.data())); |
|
|
|
|
|
|
|
auto copy_len = swap_indices_lens * sizeof(int); |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids_ptr.get(), copy_len); |
|
|
|
size_t copy_len = swap_indices_lens * sizeof(int); |
|
|
|
size_t dest_len = copy_len; |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), dest_len, host_to_server_ids_ptr.get(), copy_len); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "Lookup id memcpy failed."; |
|
|
|
return false; |
|
|
|
@@ -1052,8 +1059,9 @@ bool PsCacheManager::SyncDeviceEmbeddingTable() { |
|
|
|
RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_lens, device_hash_table_addr_tmp.get(), |
|
|
|
device_to_server_indices_ptr.get(), swap_out_data.data())); |
|
|
|
|
|
|
|
auto copy_len = swap_indices_lens * sizeof(int); |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), copy_len, device_to_server_ids_ptr.get(), copy_len); |
|
|
|
size_t copy_len = swap_indices_lens * sizeof(int); |
|
|
|
size_t dest_len = copy_len; |
|
|
|
auto ret = memcpy_s(lookup_ids.data(), dest_len, device_to_server_ids_ptr.get(), copy_len); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(ERROR) << "Lookup id memcpy failed."; |
|
|
|
return false; |
|
|
|
|