diff --git a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h index 0fe17a5d73..7019e694bd 100644 --- a/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h +++ b/mindspore/ccsrc/ps/ps_cache/embedding_hash_map.h @@ -51,6 +51,11 @@ class EmbeddingHashMap { graph_running_index_pos_(0), expired_element_full_(false) { hash_map_elements_.resize(hash_capacity); + // In multi-device mode, embedding table are distributed on different devices by ID interval, + // and IDs outside the range of local device will use the front and back positions of the table, + // the positions are reserved for this. + hash_map_elements_.front().set_step(SIZE_MAX); + hash_map_elements_.back().set_step(SIZE_MAX); graph_running_index_ = std::make_unique(hash_capacity); } virtual ~EmbeddingHashMap() = default;