Browse Source

!10116 add ps check

From: @limingqi107
Reviewed-by: @cristoval,@kisnwang
Signed-off-by: @cristoval
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
67c3fded73
2 changed files with 10 additions and 6 deletions
  1. +9
    -5
      mindspore/ccsrc/runtime/device/kernel_runtime.cc
  2. +1
    -1
      mindspore/nn/layer/embedding.py

+ 9
- 5
mindspore/ccsrc/runtime/device/kernel_runtime.cc View File

@@ -30,6 +30,7 @@
#include "utils/ms_utils.h"
#include "utils/shape_utils.h"
#include "utils/utils.h"
#include "frontend/parallel/context.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/ps_cache/ps_cache_manager.h"
#endif
@@ -1057,10 +1058,13 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph,
MS_EXCEPTION_IF_NULL(input_index.first);
}
if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from "
<< input_index.first->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
"parameter server training mode.";
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch();
if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) {
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope()
<< ") cache is from " << input_index.first->fullname_with_scope();
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in "
"parameter server training mode.";
}
}
*first_cache_input_index = input_index.first;
*first_cache_size = size;
@@ -1099,7 +1103,7 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name);
if (size != first_cache_size) {
MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope()
<< ") is not the same as other embeddingLookup cache size.";
<< ") is not the same as other embeddingLookup cache size(" << first_cache_size << ").";
MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode.";
}
} else if (ps::ps_cache_instance.IsHashTable(param_name)) {


+ 1
- 1
mindspore/nn/layer/embedding.py View File

@@ -277,7 +277,7 @@ class EmbeddingLookup(Cell):
if is_auto_parallel:
device_num = get_group_size()
full_batch = _get_full_batch()
if device_num > 1 and not (full_batch and slice_mode == TABLE_ROW_SLICE):
if device_num > 1 and not (full_batch and slice_mode == "table_row_slice"):
raise ValueError("The embeddingLookup cache of parameter server parallel only be used "
"in 'full_batch' and 'table_row_slice' parallel strategy.")
self.vocab_cache_size = self.vocab_cache_size * device_num


Loading…
Cancel
Save