Browse Source

!6807 Set weight size limit for PS mode.

Merge pull request !6807 from ZPaC/master-embedding-lookup-large-tensor-limit
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
2e1742ead7
3 changed files with 11 additions and 2 deletions
  1. +4
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc
  2. +4
    -2
      mindspore/ccsrc/ps/parameter_server.h
  3. +3
    -0
      mindspore/ccsrc/ps/worker.h

+ 4
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/ps/embedding_look_up_proxy_kernel.cc View File

@@ -29,6 +29,10 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
for (auto dim : input_shape) {
input_dims_ *= dim;
}
if (input_dims_ * sizeof(float) > INT_MAX) {
MS_LOG(EXCEPTION) << "PS mode embedding lookup max embedding table size is " << INT_MAX << ", current shape "
<< input_shape << " is too large.";
}

if (mindspore::ps::Util::IsRoleOfWorker()) {
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);


+ 4
- 2
mindspore/ccsrc/ps/parameter_server.h View File

@@ -486,7 +486,8 @@ void ParameterServer<T>::InitEmbeddingTable(
// Init embedding weight
const std::vector<size_t> &input_shapes = lookup->input_sizes();
size_t total_dims = std::accumulate(input_shapes.begin(), input_shapes.end(), 1, std::multiplies<size_t>());
size_t total_dims =
std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies<size_t>());
WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
MS_EXCEPTION_IF_NULL(embedding);
T *embedding_data = embedding->data();
@@ -732,7 +733,8 @@ void ParameterServer<T>::SyncEmbeddingTables() {
for (auto embedding_table : embedding_tables_) {
Key key = embedding_table.first;
if (embedding_lookup_ops_.count(key) == 0) {
MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key;
MS_LOG(WARNING) << "Can't find look up PS kernel for key " << key;
continue;
}
auto lookup = embedding_lookup_ops_[key];
const std::vector<size_t> &input_shapes = lookup->input_sizes();


+ 3
- 0
mindspore/ccsrc/ps/worker.h View File

@@ -325,6 +325,9 @@ void Worker<T>::InitPSParamAndOptim(const std::string &param_name, const tensor:
MS_EXCEPTION_IF_NULL(tensor);
void *param_data = tensor->data_c();
size_t param_size = LongToSize(tensor->data().nbytes());
if (param_size > INT_MAX) {
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size;
}
ShapeVector param_shape = tensor->shape_c();

size_t param_key = GetParamKey(param_name);


Loading…
Cancel
Save