|
|
|
@@ -413,7 +413,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, |
|
|
|
RETURN_IF_FALSE(ParseHostDataHostToDevice(id)); |
|
|
|
} |
|
|
|
if (need_swap_device_to_host) { |
|
|
|
RETURN_IF_FALSE(ParseHostDataDeviceToHost(id)); |
|
|
|
RETURN_IF_FALSE(ParseHostDataDeviceToHost()); |
|
|
|
} |
|
|
|
} |
|
|
|
return true; |
|
|
|
@@ -515,7 +515,7 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { |
|
|
|
bool PsCacheManager::ParseHostDataDeviceToHost() { |
|
|
|
MS_ERROR_IF_NULL(embedding_device_cache_); |
|
|
|
int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); |
|
|
|
int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); |
|
|
|
@@ -536,8 +536,8 @@ bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { |
|
|
|
int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); |
|
|
|
int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); |
|
|
|
while (true) { |
|
|
|
auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, |
|
|
|
graph_running_step_, &statistics_info_.host_to_server_size_); |
|
|
|
auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids, |
|
|
|
data_step_, graph_running_step_, &statistics_info_.host_to_server_size_); |
|
|
|
if (index == INVALID_INDEX_VALUE) { |
|
|
|
RETURN_IF_FALSE(WaitGraphRun()); |
|
|
|
continue; |
|
|
|
|