| @@ -16,6 +16,7 @@ | |||
| from .model import PrimLib, Graph, Tensor | |||
| use_poly_reduce = False | |||
| class GraphSplitByPattern: | |||
| """Graph splitter""" | |||
| @@ -24,14 +25,25 @@ class GraphSplitByPattern: | |||
| MODE_BASIC = 1 | |||
| MODE_COMPOSITE = 2 | |||
| def __init__(self, init_op): | |||
| def __init__(self, init_op, is_output): | |||
| self.pattern = PrimLib.iter_type(init_op) | |||
| self.ops = [init_op] | |||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| self.mode = self.MODE_BASIC | |||
| if self.pattern == PrimLib.TRANSFORM: | |||
| if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE): | |||
| self.mode = self.MODE_COMPOSITE | |||
| self.is_output = is_output | |||
| self.output_excluded = set() | |||
| if self.pattern == PrimLib.REDUCE: | |||
| def _gather_reduce_exclude(op): | |||
| for to in op.output.to_ops: | |||
| idx = to.inputs.index(op.output) | |||
| if self.get_relation(to, idx) > PrimLib.ELEMWISE: | |||
| self.output_excluded.add(to) | |||
| else: | |||
| _gather_reduce_exclude(to) | |||
| _gather_reduce_exclude(init_op) | |||
| def __str__(self): | |||
| return '<' + '-'.join([op.output.name for op in self.ops]) + '>' | |||
| @@ -39,18 +51,21 @@ class GraphSplitByPattern: | |||
| def __repr__(self): | |||
| return str(self) | |||
| def get_relation(self, op, i): | |||
| relation = PrimLib.UNKNOWN | |||
| _, elem_relation = PrimLib.input_relation(op, i) | |||
| for r in elem_relation: | |||
| if r is None: | |||
| relation = max(relation, PrimLib.BROADCAST) | |||
| elif r > relation: | |||
| relation = r | |||
| return relation | |||
| def link_input(self, area_map): | |||
| """Link inputs""" | |||
| def get_relation(op, i): | |||
| relation = PrimLib.UNKNOWN | |||
| _, elem_relation = PrimLib.input_relation(op, i) | |||
| for r in elem_relation: | |||
| if r is not None and r > relation: | |||
| relation = r | |||
| return relation | |||
| for i, t in enumerate(self.ops[0].inputs): | |||
| if t.op is not None: | |||
| area, relation = area_map[t.op], get_relation(self.ops[0], i) | |||
| area, relation = area_map[t.op], self.get_relation(self.ops[0], i) | |||
| self.in_relations[area] = relation | |||
| def link_output(self): | |||
| @@ -79,7 +94,10 @@ class GraphSplitByPattern: | |||
| r = rels.pop(area) | |||
| _update_relation(rels, self, r) | |||
| self.ops.extend(area.ops) | |||
| if self.pattern >= area.pattern: | |||
| self.ops.extend(area.ops) | |||
| else: | |||
| self.ops = area.ops + self.ops | |||
| _update_pattern() | |||
| _fuse_relation(self.in_relations, area.in_relations) | |||
| _fuse_relation(self.out_relations, area.out_relations) | |||
| @@ -89,6 +107,10 @@ class GraphSplitByPattern: | |||
| _redirect_relation(a.in_relations) | |||
| if self.pattern > PrimLib.RESHAPE: | |||
| self.mode = self.MODE_COMPOSITE | |||
| if area.is_output and not self.is_output: | |||
| self.is_output = True | |||
| if area.output_excluded: | |||
| self.output_excluded.update(area.output_excluded) | |||
| def check_circle(self, to): | |||
| """Check circle. It returns false if circle exists""" | |||
| @@ -102,15 +124,27 @@ class GraphSplitByPattern: | |||
| return False | |||
| return True | |||
| BORADCAST_FUSE_DEPTH = 3 | |||
| REDUCE_FUSE_DEPTH = 3 | |||
| def dom_op(self): | |||
| return self.ops[0] | |||
| def reduce_out_exclude(self, area): | |||
| if self.output_excluded: | |||
| for op in self.output_excluded: | |||
| if op in area.ops: | |||
| return True | |||
| return False | |||
| BORADCAST_FUSE_DEPTH = 20 | |||
| REDUCE_FUSE_DEPTH = 20 | |||
| def __init__(self, graph): | |||
| self.graph = graph | |||
| self.areas = [] | |||
| area_map = {} | |||
| _, outputs = graph.deduce_parameters() | |||
| for op in graph.ops: | |||
| a = self.Area(op) | |||
| is_output = op.output in outputs | |||
| a = self.Area(op, is_output) | |||
| self.areas.append(a) | |||
| area_map[op] = a | |||
| for a in self.areas: | |||
| @@ -123,12 +157,20 @@ class GraphSplitByPattern: | |||
| changed = False | |||
| while True: | |||
| for dominant in self.areas: | |||
| fuse_areas = selector(dominant) | |||
| if fuse_areas: | |||
| for area in fuse_areas: | |||
| changed = True | |||
| dominant.fuse(area) | |||
| self.areas.remove(area) | |||
| result = selector(dominant) | |||
| if result is not None and result[0]: | |||
| fuse_areas, is_forward = result | |||
| if is_forward: | |||
| for area in fuse_areas: | |||
| dominant.fuse(area) | |||
| self.areas.remove(area) | |||
| else: | |||
| forward_area = dominant | |||
| for area in fuse_areas: | |||
| area.fuse(forward_area) | |||
| self.areas.remove(forward_area) | |||
| forward_area = area | |||
| changed = True | |||
| break | |||
| else: | |||
| return changed | |||
| @@ -148,43 +190,69 @@ class GraphSplitByPattern: | |||
| def split(self): | |||
| """Split graph by pattern""" | |||
| def _reshape(dom): | |||
| if dom.pattern != PrimLib.RESHAPE: | |||
| return None | |||
| min_area, forward_fuse = None, False | |||
| for a, _ in dom.out_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ | |||
| (min_area is None or a.pattern < min_area.pattern): | |||
| min_area = a | |||
| for a, _ in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ | |||
| len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ | |||
| (min_area is None or a.pattern < min_area.pattern): | |||
| min_area, forward_fuse = a, True | |||
| return ([min_area], forward_fuse) if min_area else None | |||
| def _elemwise_depth(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: | |||
| return None | |||
| a, r = list(dom.in_relations.items())[0] | |||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE: | |||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \ | |||
| a.dom_op().output.shape != dom.dom_op().output.shape: | |||
| return None | |||
| return [a] | |||
| return [a], True | |||
| def _elemwise_width(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom): | |||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ | |||
| a.dom_op().output.shape == dom.dom_op().output.shape: | |||
| fused.append(a) | |||
| return fused | |||
| return fused, True | |||
| def _broadcast_pat_exclude(dom, a, r): | |||
| if use_poly_reduce and a.pattern == PrimLib.REDUCE: | |||
| return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE | |||
| return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST | |||
| def _broadcast_depth(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| return None | |||
| a, r = list(dom.in_relations.items())[0] | |||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \ | |||
| r != PrimLib.BROADCAST or len(a.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| a, r = list(dom.out_relations.items())[0] | |||
| if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: | |||
| return None | |||
| return [a] | |||
| return [a], False | |||
| def _broadcast_width(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \ | |||
| a.check_circle(dom) and len(a.ops) <= self.BORADCAST_FUSE_DEPTH: | |||
| fused.append(a) | |||
| return fused | |||
| for a, r in dom.out_relations.items(): | |||
| if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ | |||
| (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | |||
| return None | |||
| fused.append(a) | |||
| return fused, False | |||
| def _check_reduce_exclude(dom): | |||
| if use_poly_reduce: | |||
| return False | |||
| # exclude large all-reduce | |||
| if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \ | |||
| dom.ops[0].inputs[0].get_size() > 10000: | |||
| @@ -198,16 +266,22 @@ class GraphSplitByPattern: | |||
| return True | |||
| return False | |||
| def _reduce_pat_exclude(dom, a, r): | |||
| if len(a.ops) > self.REDUCE_FUSE_DEPTH: | |||
| return True | |||
| if use_poly_reduce: | |||
| return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST | |||
| return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE | |||
| def _reduce_depth(dom): | |||
| if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: | |||
| return None | |||
| if _check_reduce_exclude(dom): | |||
| return None | |||
| a, r = list(dom.in_relations.items())[0] | |||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \ | |||
| r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH: | |||
| if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: | |||
| return None | |||
| return [a] | |||
| return [a], True | |||
| def _reduce_width(dom): | |||
| if dom.pattern != PrimLib.REDUCE: | |||
| @@ -216,18 +290,51 @@ class GraphSplitByPattern: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \ | |||
| a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH: | |||
| if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): | |||
| fused.append(a) | |||
| return fused | |||
| return fused, True | |||
| def _tensor_size(tensor): | |||
| size = 1 | |||
| for i in tensor.shape: | |||
| size *= i | |||
| return size | |||
| def _reduce_output(dom): | |||
| if dom.pattern != PrimLib.REDUCE: | |||
| return None | |||
| is_all_reduce = _tensor_size(dom.ops[0].output) == 1 | |||
| # excluded large size all reduce | |||
| if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | |||
| dom.check_circle(a) and not dom.reduce_out_exclude(a): | |||
| fused.append(a) | |||
| return fused, False | |||
| def _transpose(dom): | |||
| if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose": | |||
| return None | |||
| fused = [] | |||
| for a, _ in dom.in_relations.items(): | |||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom): | |||
| fused.append(a) | |||
| return fused, True | |||
| changed = True | |||
| while changed: | |||
| changed = self.fuse(_elemwise_depth) | |||
| changed = self.fuse(_reshape) | |||
| changed = self.fuse(_elemwise_depth) or changed | |||
| changed = self.fuse(_elemwise_width) or changed | |||
| changed = self.fuse(_broadcast_depth) or changed | |||
| changed = self.fuse(_broadcast_width) or changed | |||
| changed = self.fuse(_reduce_depth) or changed | |||
| changed = self.fuse(_reduce_width) or changed | |||
| changed = self.fuse(_broadcast_depth) or changed | |||
| changed = self.fuse(_broadcast_width) or changed | |||
| if use_poly_reduce: | |||
| changed = self.fuse(_reduce_output) or changed | |||
| self.fuse(_transpose) | |||
| subgraphs, graphmodes = self.to_subgraphs() | |||
| return subgraphs, graphmodes | |||