|
|
|
@@ -30,6 +30,7 @@ |
|
|
|
#include "utils/ms_utils.h" |
|
|
|
#include "utils/shape_utils.h" |
|
|
|
#include "utils/utils.h" |
|
|
|
#include "frontend/parallel/context.h" |
|
|
|
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) |
|
|
|
#include "ps/ps_cache/ps_cache_manager.h" |
|
|
|
#endif |
|
|
|
@@ -1057,10 +1058,13 @@ void KernelRuntime::GetFirstPSEmbeddingCache(const session::KernelGraph *graph, |
|
|
|
MS_EXCEPTION_IF_NULL(input_index.first); |
|
|
|
} |
|
|
|
if ((!input_index.first->isa<Parameter>()) && (AnfAlgo::GetCNodeName(input_index.first) != "GetNext")) { |
|
|
|
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() << ") cache is from " |
|
|
|
<< input_index.first->fullname_with_scope(); |
|
|
|
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " |
|
|
|
"parameter server training mode."; |
|
|
|
bool full_batch = parallel::ParallelContext::GetInstance()->full_batch(); |
|
|
|
if ((!full_batch) || (AnfAlgo::GetCNodeName(input_index.first) != "Minimum")) { |
|
|
|
MS_LOG(ERROR) << "The input index of the embeddingLookup(" << kernel->fullname_with_scope() |
|
|
|
<< ") cache is from " << input_index.first->fullname_with_scope(); |
|
|
|
MS_LOG(EXCEPTION) << "The embeddingLookup whose input index isn't from dataset doesn't support cache in " |
|
|
|
"parameter server training mode."; |
|
|
|
} |
|
|
|
} |
|
|
|
*first_cache_input_index = input_index.first; |
|
|
|
*first_cache_size = size; |
|
|
|
@@ -1099,7 +1103,7 @@ void KernelRuntime::CheckIfSupportPSEmbeddingCache(const session::KernelGraph *g |
|
|
|
auto size = ps::ps_cache_instance.QueryHashTableSize(param_name); |
|
|
|
if (size != first_cache_size) { |
|
|
|
MS_LOG(ERROR) << "The cache size(" << size << ") of embeddingLookup(" << kernel->fullname_with_scope() |
|
|
|
<< ") is not the same as other embeddingLookup cache size."; |
|
|
|
<< ") is not the same as other embeddingLookup cache size(" << first_cache_size << ")."; |
|
|
|
MS_LOG(EXCEPTION) << "The cache sizes of embeddingLookups are not the same in parameter server training mode."; |
|
|
|
} |
|
|
|
} else if (ps::ps_cache_instance.IsHashTable(param_name)) { |
|
|
|
|