Browse Source

!13035 fix mix cache

From: @fangzehua
Reviewed-by: @kisnwang,@liangchenghui
Signed-off-by: @liangchenghui
tags/v1.2.0-rc1
mindspore-ci-bot Gitee 4 years ago
parent
commit
41b78ddc7e
1 changed files with 1 additions and 4 deletions
  1. +1
    -4
      mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc

+ 1
- 4
mindspore/ccsrc/frontend/parallel/cache_embedding/cache_embedding.cc View File

@@ -446,10 +446,7 @@ CNodePtrList FindSparseGatherV2WithCache(const CNodePtrList &cnodes, const Param
for (size_t i = 0; i < cnodes_size; ++i) {
if (IsPrimitiveCNode(cnodes[i], prim::kPrimSparseGatherV2)) {
auto load_node = cnodes[i]->input(1);
if (IsPrimitiveCNode(load_node, prim::kPrimCast)) {
load_node = load_node->cast<CNodePtr>()->input(1);
}
if (IsPrimitiveCNode(load_node, prim::kPrimLoad)) {
if (IsPrimitiveCNode(load_node, prim::kPrimLoad) || IsPrimitiveCNode(load_node, prim::kPrimCast)) {
auto param_node = load_node->cast<CNodePtr>()->input(1)->cast<ParameterPtr>();
if (param_set.find(param_node) != param_set.end()) {
sparse_gather_v2_with_cache.push_back(cnodes[i]);


Loading…
Cancel
Save