Browse Source

!13227 remove embeddinglookup input0 int64

From: @fangzehua
Reviewed-by: 
Signed-off-by:
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 5 years ago
parent
commit
d943d22b49
2 changed files with 10 additions and 9 deletions
  1. +0
    -5
      mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h
  2. +10
    -4
      mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc

+ 0
- 5
mindspore/ccsrc/backend/kernel_compiler/cpu/embedding_look_up_cpu_kernel.h View File

@@ -54,11 +54,6 @@ MS_REG_CPU_KERNEL(
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
EmbeddingLookUpCPUKernel);

MS_REG_CPU_KERNEL(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
EmbeddingLookUpCPUKernel);

MS_REG_CPU_KERNEL(
EmbeddingLookup,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),


+ 10
- 4
mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc View File

@@ -79,8 +79,8 @@ bool CheckHostCacheParamSize(const ParamSet &parameter_cache_enable_set) {
cache_size = tmp_cache_size;
host_size = tmp_host_size;
}
if (cache_size >= host_size) {
MS_LOG(WARNING) << "vocab_cache_size >= vocab_size, there is no need use cache.";
if (cache_size > host_size) {
MS_LOG(WARNING) << "vocab_cache_size > vocab_size, there is no need use cache.";
return false;
}
return true;
@@ -444,12 +444,18 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
size_t cnodes_size = cnodes.size();
CNodePtrList sparse_gather_v2_with_cache;
for (size_t i = 0; i < cnodes_size; ++i) {
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2) ||
IsPrimitiveCNode(cnodes[i], prim::kPrimEmbeddingLookup)) {
auto load_node = cnodes[i]->input(1);
if (IsPrimitiveCNode(load_node, prim::kPrimLoad) || IsPrimitiveCNode(load_node, prim::kPrimCast)) {
if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
load_node = load_node->cast<CNodePtr>()->input(1);
}
if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
if (param_set.find(param_node) != param_set.end()) {
sparse_gather_v2_with_cache.push_back(cnodes[i]);
} else {
MS_LOG(EXCEPTION) << "EmbeddingLookup can't not support cache and no cache in the same graph.";
}
}
}


Loading…
Cancel
Save