diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 89583a3aa3..a9a13d11c5 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -30,6 +30,7 @@ #include "utils/ms_utils.h" #include "utils/shape_utils.h" #include "utils/utils.h" +#include "frontend/parallel/context.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "ps/ps_cache/ps_cache_manager.h" #endif @@ -1057,10 +1058,13 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, MS_EXCEPTION_IF_NULL(input_index.first); } if ((!input_index.first->isa()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { - MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from " - << input_index.first->fullname_with_scope(); - MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " - "parameter server training mode."; + bool full_batch = parallel::ParallelContext::GetInstance()->full_batch(); + if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) { + MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() + << ") cache is from " << input_index.first->fullname_with_scope(); + MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " + "parameter server training mode."; + } } *first_cache_input_index = input_index.first; *first_cache_size = size; @@ -1099,7 +1103,7 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); if (size != first_cache_size) { MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope() - << ") is not the same as other embeddingLookup cache size."; + << ") is not the same as other embeddingLookup cache size(" << first_cache_size << ")."; MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode."; } } else if (ps::ps_cache_instance.IsHashTable(param_name)) { diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index e31d1f87e7..f73c9bfbb1 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -277,7 +277,7 @@ class EmbeddingLookup(Cell): if is_auto_parallel: device_num = get_group_size() full_batch = _get_full_batch() - if device_num > 1 and not (full_batch and slice_mode == TABLE_ROW_SLICE): + if device_num > 1 and not (full_batch and slice_mode == "table_row_slice"): raise ValueError("The embeddingLookup cache of parameter server parallel only be used " "in 'full_batch' and 'table_row_slice' parallel strategy.") self.vocab_cache_size = self.vocab_cache_size * device_num