| @@ -499,15 +499,20 @@ void TraverseGraphMap( | |||||
| for (auto &use : users) { | for (auto &use : users) { | ||||
| CNodePtr node = use.first->cast<CNodePtr>(); | CNodePtr node = use.first->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| if (node->func_graph() != fg) { | |||||
| continue; | |||||
| } | |||||
| int key = use.second; | int key = use.second; | ||||
| if (key != 0) { | if (key != 0) { | ||||
| MS_EXCEPTION_IF_NULL(node->input(0)); | MS_EXCEPTION_IF_NULL(node->input(0)); | ||||
| bool key_is_const = node->input(0)->isa<ValueNode>(); | bool key_is_const = node->input(0)->isa<ValueNode>(); | ||||
| PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0)); | PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0)); | ||||
| bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name())); | |||||
| bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name())); | |||||
| if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) { | |||||
| continue; | |||||
| if (value != nullptr) { | |||||
| bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name())); | |||||
| bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name())); | |||||
| if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) { | |||||
| continue; | |||||
| } | |||||
| } | } | ||||
| FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node), | FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node), | ||||
| dyn_cast<AbstractFunction>(const_primitive_node->abstract())); | dyn_cast<AbstractFunction>(const_primitive_node->abstract())); | ||||
| @@ -554,6 +559,7 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { | |||||
| FuncGraphTransaction tr = manager_ptr->Transact(); | FuncGraphTransaction tr = manager_ptr->Transact(); | ||||
| auto &fgs = manager_ptr->func_graphs(); | auto &fgs = manager_ptr->func_graphs(); | ||||
| TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); | TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); | ||||
| tr.Commit(); | |||||
| return graph; | return graph; | ||||
| } | } | ||||