diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc index bf4fb776ff..2d7715f284 100644 --- a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc +++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc @@ -446,6 +446,9 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param for (size_t i = 0; i < cnodes_size; ++i) { if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) { auto load_node = cnodes[i]->input(1); + if (IsPrimitiveCNode(load_node, prim::kPrimCast)) { + load_node = load_node->cast()->input(1); + } if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) { auto param_node = load_node->cast()->input(1)->cast(); if (param_set.find(param_node) != param_set.end()) { diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index db50fea057..3f0b85dfcc 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -261,6 +261,7 @@ class ModelCheckpoint(Callback): self._manager = CheckpointManager() self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix) self._graph_saved = False + self._need_flush_from_cache = True def step_end(self, run_context): """ @@ -326,7 +327,8 @@ class ModelCheckpoint(Callback): return # if param is cache enable, flush data from cache to host before save_ckpt - self._flush_from_cache(cb_params) + if self._need_flush_from_cache: + self._flush_from_cache(cb_params) save_ckpt = self._check_save_ckpt(cb_params, force_to_save) step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1) @@ -365,10 +367,14 @@ class ModelCheckpoint(Callback): def _flush_from_cache(self, cb_params): """Flush cache data to host if tensor is cache enable.""" + has_cache_params = False params = cb_params.train_network.get_parameters() for param in params: if param.cache_enable: + has_cache_params = True Tensor(param).flush_from_cache() + if not has_cache_params: + self._need_flush_from_cache = False @property def latest_ckpt_file_name(self):