diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h index b1639100da..cbd7fe30a0 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h @@ -54,11 +54,6 @@ MS_REG_CPU_KERNEL( KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), EmbeddingLookUpCPUKernel); -MS_REG_CPU_KERNEL( - EmbeddingLookup, - KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - EmbeddingLookUpCPUKernel); - MS_REG_CPU_KERNEL( EmbeddingLookup, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), diff --git a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc index cb09b7db31..e6d00e7620 100644 --- a/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc +++ b/mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc @@ -79,8 +79,8 @@ bool CheckHostCacheParamSize(const ParamSet ¶meter_cache_enable_set) { cache_size = tmp_cache_size; host_size = tmp_host_size; } - if (cache_size >= host_size) { - MS_LOG(WARNING) << "vocab_cache_size >= vocab_size, there is no need use cache."; + if (cache_size > host_size) { + MS_LOG(WARNING) << "vocab_cache_size > vocab_size, there is no need use cache."; return false; } return true; @@ -444,12 +444,18 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param size_t cnodes_size = cnodes.size(); CNodePtrList sparse_gather_v2_with_cache; for (size_t i = 0; i < cnodes_size; ++i) { - if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) { + if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2) || + IsPrimitiveCNode(cnodes[i], prim::kPrimEmbeddingLookup)) { auto load_node = cnodes[i]->input(1); - if (IsPrimitiveCNode(load_node, prim::kPrimLoad) || IsPrimitiveCNode(load_node, prim::kPrimCast)) { + if (IsPrimitiveCNode(load_node, prim::kPrimCast)) { + load_node = load_node->cast()->input(1); + } + if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) { auto param_node = load_node->cast()->input(1)->cast(); if (param_set.find(param_node) != param_set.end()) { sparse_gather_v2_with_cache.push_back(cnodes[i]); + } else { + MS_LOG(EXCEPTION) << "EmbeddingLookup can't not support cache and no cache in the same graph."; } } }