|
|
|
@@ -79,8 +79,8 @@ bool CheckHostCacheParamSize(const ParamSet ¶meter_cache_enable_set) { |
|
|
|
cache_size = tmp_cache_size; |
|
|
|
host_size = tmp_host_size; |
|
|
|
} |
|
|
|
if (cache_size >= host_size) { |
|
|
|
MS_LOG(WARNING) << "vocab_cache_size >= vocab_size, there is no need use cache."; |
|
|
|
if (cache_size > host_size) { |
|
|
|
MS_LOG(WARNING) << "vocab_cache_size > vocab_size, there is no need use cache."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
@@ -444,12 +444,18 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param |
|
|
|
size_t cnodes_size = cnodes.size(); |
|
|
|
CNodePtrList sparse_gather_v2_with_cache; |
|
|
|
for (size_t i = 0; i < cnodes_size; ++i) { |
|
|
|
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) { |
|
|
|
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2) || |
|
|
|
IsPrimitiveCNode(cnodes[i], prim::kPrimEmbeddingLookup)) { |
|
|
|
auto load_node = cnodes[i]->input(1); |
|
|
|
if (IsPrimitiveCNode(load_node, prim::kPrimLoad) || IsPrimitiveCNode(load_node, prim::kPrimCast)) { |
|
|
|
if (IsPrimitiveCNode(load_node, prim::kPrimCast)) { |
|
|
|
load_node = load_node->cast<CNodePtr>()->input(1); |
|
|
|
} |
|
|
|
if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) { |
|
|
|
auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>(); |
|
|
|
if (param_set.find(param_node) != param_set.end()) { |
|
|
|
sparse_gather_v2_with_cache.push_back(cnodes[i]); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "EmbeddingLookup can't not support cache and no cache in the same graph."; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|