diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 93d5f33cf4..636d36f931 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -499,15 +499,20 @@ void TraverseGraphMap( for (auto &use : users) { CNodePtr node = use.first->cast(); MS_EXCEPTION_IF_NULL(node); + if (node->func_graph() != fg) { + continue; + } int key = use.second; if (key != 0) { MS_EXCEPTION_IF_NULL(node->input(0)); bool key_is_const = node->input(0)->isa(); PrimitivePtr value = GetValueNode(node->input(0)); - bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name())); - bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name())); - if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) { - continue; + if (value != nullptr) { + bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name())); + bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name())); + if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) { + continue; + } } FuncGraphPtr g = get_prim_graph(GetValueNode(const_primitive_node), dyn_cast(const_primitive_node->abstract())); @@ -554,6 +559,7 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { FuncGraphTransaction tr = manager_ptr->Transact(); auto &fgs = manager_ptr->func_graphs(); TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); + tr.Commit(); return graph; }