From: @gaoxiong1 Reviewed-by: @anyrenwei,@dylangeng Signed-off-by: @dylangengpull/15658/MERGE
| @@ -30,7 +30,7 @@ class GraphSplitByPattern: | |||
| self.stitch_ops = set() | |||
| self.stitch_atomic_ops = set() | |||
| def __init__(self, init_op, is_output): | |||
| def __init__(self, init_op, is_output, unique_id): | |||
| self.pattern = PrimLib.iter_type(init_op) | |||
| self.ops = [init_op] | |||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| @@ -48,6 +48,8 @@ class GraphSplitByPattern: | |||
| else: | |||
| _gather_reduce_exclude(to) | |||
| _gather_reduce_exclude(init_op) | |||
| self.reach_map = dict() | |||
| self.unique_id = unique_id | |||
| def __str__(self): | |||
| return '<' + '-'.join([op.output.name for op in self.ops]) + '>' | |||
| @@ -122,13 +124,20 @@ class GraphSplitByPattern: | |||
| if area.output_excluded: | |||
| self.output_excluded.update(area.output_excluded) | |||
| self.update_stitch_info(area.stitch_info) | |||
| for to, reach in area.reach_map.items(): | |||
| if reach and not self.reach_map.get(to, False): | |||
| self.reach_map[to] = True | |||
| def check_acyclic(self, to): | |||
| """Check circle. It returns false if circle exists""" | |||
| def _reached(area, to): | |||
| if to.unique_id in area.reach_map: | |||
| return area.reach_map[to.unique_id] | |||
| for out, _ in area.out_relations.items(): | |||
| if out == to or _reached(out, to): | |||
| area.reach_map[to.unique_id] = True | |||
| return True | |||
| area.reach_map[to.unique_id] = False | |||
| return False | |||
| for out, _ in self.out_relations.items(): | |||
| if out != to and _reached(out, to): | |||
| @@ -151,9 +160,11 @@ class GraphSplitByPattern: | |||
| self.flags = flags | |||
| area_map = {} | |||
| _, outputs = graph.deduce_parameters() | |||
| idx = 0 | |||
| for op in graph.ops: | |||
| is_output = op.output in outputs | |||
| a = self.Area(op, is_output) | |||
| a = self.Area(op, is_output, idx) | |||
| idx += 1 | |||
| self.set_default_mode(a) | |||
| self.areas.append(a) | |||
| area_map[op] = a | |||
| @@ -249,7 +260,7 @@ class GraphSplitByPattern: | |||
| break | |||
| if out_reshape_ops: | |||
| for op in out_reshape_ops: | |||
| a = self.Area(op, False) | |||
| a = self.Area(op, False, -1) | |||
| self.set_default_mode(a) | |||
| new_areas.append(a) | |||
| area.ops = remain_ops | |||
| @@ -448,45 +448,26 @@ class Graph: | |||
| class GraphVisitor: | |||
| """Graph visitor""" | |||
| def __init__(self, forward=True, once_mode=True): | |||
| def __init__(self, forward=True): | |||
| self.forward = forward | |||
| self.once_mode = once_mode | |||
| if self.once_mode: | |||
| self.visited = set() | |||
| def visit_graph(self, graph): | |||
| """Visit graph""" | |||
| inputs, outputs = graph.deduce_parameters() | |||
| if self.forward: | |||
| for tensor in inputs: | |||
| for op in tensor.to_ops: | |||
| self.visit(op) | |||
| for op in graph.ops: | |||
| self.visit(op) | |||
| else: | |||
| for tensor in outputs: | |||
| if not tensor.to_ops: | |||
| self.visit(tensor.op) | |||
| def visit(self, op): | |||
| """Visit op""" | |||
| next_ops = op.output.to_ops if self.forward else [ | |||
| t.op for t in op.inputs if t.op is not None] | |||
| if self.once_mode: | |||
| self.visited.add(op) | |||
| for n in next_ops: | |||
| if n not in self.visited: | |||
| self.visit(n) | |||
| else: | |||
| for n in next_ops: | |||
| self.visit(n) | |||
| for i in range(len(graph.ops)-1, -1, -1): | |||
| self.visit(graph.ops[i]) | |||
| class AlignShape(GraphVisitor): | |||
| """Align shape""" | |||
| def __init__(self): | |||
| super().__init__(once_mode=False) | |||
| super().__init__() | |||
| def visit(self, op): | |||
| """Visit op node""" | |||
| prim = PrimLib.get_prim(op) | |||
| if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE): | |||
| out_dim = len(op.output.shape) | |||
| @@ -496,8 +477,6 @@ class AlignShape(GraphVisitor): | |||
| align_dim = len(t.shape) | |||
| if align_dim > out_dim: | |||
| op.output.shape = [1] * (align_dim - out_dim) + op.output.shape | |||
| super().visit(op) | |||
| class AddControlBuddy(GraphVisitor): | |||
| """Add control buddy""" | |||
| @@ -507,6 +486,7 @@ class AddControlBuddy(GraphVisitor): | |||
| self.buddies = {} # {op : [ctrl_op]} | |||
| def visit(self, op): | |||
| """Visit op node""" | |||
| if op.prim == "MakeTuple": | |||
| assert len(op.output.to_ops) == 1 | |||
| owner = op.output.to_ops[0] | |||
| @@ -517,9 +497,9 @@ class AddControlBuddy(GraphVisitor): | |||
| if op in self.buddies: | |||
| ops = self.buddies.pop(op) | |||
| self.buddies[owner].extend(ops) | |||
| super().visit(op) | |||
| def visit_graph(self, graph): | |||
| """Visit graph nodes""" | |||
| super().visit_graph(graph) | |||
| for owner in self.buddies: | |||
| for op in self.buddies[owner]: | |||