diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc index 2d7715f284..6960c72b7a 100644 --- a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc +++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc @@ -446,10 +446,7 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param for (size_t i = 0; i < cnodes_size; ++i) { if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) { auto load_node = cnodes[i]->input(1); - if (IsPrimitiveCNode(load_node, prim::kPrimCast)) { - load_node = load_node->cast()->input(1); - } - if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) { + if (IsPrimitiveCNode(load_node, prim::kPrimLoad) || IsPrimitiveCNode(load_node, prim::kPrimCast)) { auto param_node = load_node->cast()->input(1)->cast(); if (param_set.find(param_node) != param_set.end()) { sparse_gather_v2_with_cache.push_back(cnodes[i]);