From 7eb49cfce7d67049dc0f0183dbbc19496116cfca Mon Sep 17 00:00:00 2001
From: lizhenyu
Date: Mon, 28 Dec 2020 18:10:37 +0800
Subject: [PATCH] [bugfix] server core dump after training

---
 mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc          | 8 ++++----
 mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h           | 2 +-
 mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc | 2 +-
 mindspore/parallel/_auto_parallel_context.py             | 5 +++++
 4 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
index 6e2babb52a..050065b211 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc
@@ -413,7 +413,7 @@ bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len,
       RETURN_IF_FALSE(ParseHostDataHostToDevice(id));
     }
     if (need_swap_device_to_host) {
-      RETURN_IF_FALSE(ParseHostDataDeviceToHost(id));
+      RETURN_IF_FALSE(ParseHostDataDeviceToHost());
     }
   }
   return true;
@@ -515,7 +515,7 @@ bool PsCacheManager::ParseHostDataHostToDevice(size_t id) {
   return true;
 }
 
-bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
+bool PsCacheManager::ParseHostDataDeviceToHost() {
   MS_ERROR_IF_NULL(embedding_device_cache_);
   int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get();
   int *device_to_host_index = embedding_host_cache_->device_to_host_index.get();
@@ -536,8 +536,8 @@ bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) {
   int *host_to_server_index = embedding_host_cache_->host_to_server_index.get();
   int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get();
   while (true) {
-    auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_,
-                                          graph_running_step_, &statistics_info_.host_to_server_size_);
+    auto index = host_hash_map->ParseData(swap_device_to_host_id, host_to_server_index, host_to_server_ids,
+                                          data_step_, graph_running_step_, &statistics_info_.host_to_server_size_);
     if (index == INVALID_INDEX_VALUE) {
       RETURN_IF_FALSE(WaitGraphRun());
       continue;
diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
index 7dca6b0159..54dfa217f9 100644
--- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
+++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h
@@ -150,7 +150,7 @@ class PsCacheManager {
   bool WaitGraphRun();
   bool ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device, int *hash_index);
   bool ParseHostDataHostToDevice(size_t id);
-  bool ParseHostDataDeviceToHost(size_t id);
+  bool ParseHostDataDeviceToHost();
   bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray<float> *swap_out_data, const HashTableInfo &hash_info);
   bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key);
   bool HashSwapHostToDevice(const HashTableInfo &hash_info);
diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
index 28fa0e701c..f7263d72df 100644
--- a/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
+++ b/mindspore/ccsrc/runtime/device/kernel_runtime_manager.cc
@@ -24,7 +24,7 @@ namespace mindspore {
 namespace device {
 void KernelRuntimeManager::ClearRuntimeResource() {
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
-  if (ps::PsDataPrefetch::GetInstance().cache_enable()) {
+  if (ps::Util::IsRoleOfWorker() && ps::PsDataPrefetch::GetInstance().cache_enable()) {
     ps::ps_cache_instance.SyncEmbeddingTable();
   }
 #endif
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index 24cb8702af..ed68fdf76b 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -16,6 +16,7 @@
 import threading
 import mindspore.context as context
 from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, \
     _set_fusion_strategy_by_size
+from mindspore.parallel._ps_context import _is_role_pserver
 from mindspore._c_expression import AutoParallelContext
 from mindspore._checkparam import args_type_check
@@ -180,6 +181,8 @@ class _AutoParallelContext:
     def get_parallel_mode(self):
         """Get parallel mode."""
         self.check_context_handle()
+        if _is_role_pserver():
+            return context.ParallelMode.STAND_ALONE
         return self._context_handle.get_parallel_mode()
 
     def set_strategy_search_mode(self, auto_parallel_search_mode):
@@ -242,6 +245,8 @@ class _AutoParallelContext:
     def get_full_batch(self):
         """Get whether load full batch on each device."""
         self.check_context_handle()
+        if _is_role_pserver():
+            return False
         return self._context_handle.get_full_batch()
 
     def set_strategy_ckpt_save_file(self, strategy_ckpt_save_file):