From 15be0cc819ebc99d5355838de856921c176eabd9 Mon Sep 17 00:00:00 2001 From: ZPaC Date: Wed, 23 Sep 2020 21:09:39 +0800 Subject: [PATCH] Set weight size limit for PS mode. --- .../cpu/ps/embedding_look_up_proxy_kernel.cc | 4 ++++ mindspore/ccsrc/ps/parameter_server.h | 6 ++++-- mindspore/ccsrc/ps/worker.h | 3 +++ 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc index eeaef7974d..a6c8b7f314 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc @@ -29,6 +29,10 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) { for (auto dim : input_shape) { input_dims_ *= dim; } + if (input_dims_ * sizeof(float) > INT_MAX) { + MS_LOG(EXCEPTION) << "PS mode embedding lookup max embedding table size is " << INT_MAX << ", current shape " + << input_shape << " is too large."; + } if (mindspore::ps::Util::IsRoleOfWorker()) { key_ = AnfAlgo::GetNodeAttr(kernel_node, kAttrPsKey); diff --git a/mindspore/ccsrc/ps/parameter_server.h b/mindspore/ccsrc/ps/parameter_server.h index b19dace437..86318a1875 100644 --- a/mindspore/ccsrc/ps/parameter_server.h +++ b/mindspore/ccsrc/ps/parameter_server.h @@ -486,7 +486,8 @@ void ParameterServer::InitEmbeddingTable( // Init embedding weight const std::vector &input_shapes = lookup->input_sizes(); - size_t total_dims = std::accumulate(input_shapes.begin(), input_shapes.end(), 1, std::multiplies()); + size_t total_dims = + std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies()); WeightPtr embedding = std::make_shared(total_dims, 0); MS_EXCEPTION_IF_NULL(embedding); T *embedding_data = embedding->data(); @@ -732,7 +733,8 @@ void ParameterServer::SyncEmbeddingTables() { for (auto embedding_table : 
embedding_tables_) { Key key = embedding_table.first; if (embedding_lookup_ops_.count(key) == 0) { - MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key; + MS_LOG(WARNING) << "Can't find look up PS kernel for key " << key; + continue; } auto lookup = embedding_lookup_ops_[key]; const std::vector &input_shapes = lookup->input_sizes(); diff --git a/mindspore/ccsrc/ps/worker.h b/mindspore/ccsrc/ps/worker.h index eda728128d..f07154e8f8 100644 --- a/mindspore/ccsrc/ps/worker.h +++ b/mindspore/ccsrc/ps/worker.h @@ -325,6 +325,9 @@ void Worker::InitPSParamAndOptim(const std::string &param_name, const tensor: MS_EXCEPTION_IF_NULL(tensor); void *param_data = tensor->data_c(); size_t param_size = LongToSize(tensor->data().nbytes()); + if (param_size > INT_MAX) { + MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size; + } ShapeVector param_shape = tensor->shape_c(); size_t param_key = GetParamKey(param_name);