|
|
|
@@ -499,15 +499,20 @@ void TraverseGraphMap( |
|
|
|
for (auto &use : users) { |
|
|
|
CNodePtr node = use.first->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
|
if (node->func_graph() != fg) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
int key = use.second; |
|
|
|
if (key != 0) { |
|
|
|
MS_EXCEPTION_IF_NULL(node->input(0)); |
|
|
|
bool key_is_const = node->input(0)->isa<ValueNode>(); |
|
|
|
PrimitivePtr value = GetValueNode<PrimitivePtr>(node->input(0)); |
|
|
|
bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name())); |
|
|
|
bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name())); |
|
|
|
if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) { |
|
|
|
continue; |
|
|
|
if (value != nullptr) { |
|
|
|
bool is_prim_array_map = !(prim::kPrimArrayMap->name().compare(value->name())); |
|
|
|
bool is_prim_array_reduce = !(prim::kPrimArrayReduce->name().compare(value->name())); |
|
|
|
if (key == 1 && key_is_const && (is_prim_array_map || is_prim_array_reduce)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
} |
|
|
|
FuncGraphPtr g = get_prim_graph(GetValueNode<PrimitivePtr>(const_primitive_node), |
|
|
|
dyn_cast<AbstractFunction>(const_primitive_node->abstract())); |
|
|
|
@@ -554,6 +559,7 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { |
|
|
|
FuncGraphTransaction tr = manager_ptr->Transact(); |
|
|
|
auto &fgs = manager_ptr->func_graphs(); |
|
|
|
TraverseGraphMap(manager_ptr, &tr, fgs, get_prim_graph); |
|
|
|
tr.Commit(); |
|
|
|
|
|
|
|
return graph; |
|
|
|
} |
|
|
|
|