| @@ -16,6 +16,7 @@ | |||||
| from .model import PrimLib, Graph, Tensor | from .model import PrimLib, Graph, Tensor | ||||
| use_poly_reduce = False | |||||
| class GraphSplitByPattern: | class GraphSplitByPattern: | ||||
| """Graph splitter""" | """Graph splitter""" | ||||
| @@ -24,14 +25,25 @@ class GraphSplitByPattern: | |||||
| MODE_BASIC = 1 | MODE_BASIC = 1 | ||||
| MODE_COMPOSITE = 2 | MODE_COMPOSITE = 2 | ||||
| def __init__(self, init_op): | |||||
| def __init__(self, init_op, is_output): | |||||
| self.pattern = PrimLib.iter_type(init_op) | self.pattern = PrimLib.iter_type(init_op) | ||||
| self.ops = [init_op] | self.ops = [init_op] | ||||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | ||||
| self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | self.out_relations = dict() # {area1: relation1, area2: relation2, ...} | ||||
| self.mode = self.MODE_BASIC | self.mode = self.MODE_BASIC | ||||
| if self.pattern == PrimLib.TRANSFORM: | |||||
| if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE): | |||||
| self.mode = self.MODE_COMPOSITE | self.mode = self.MODE_COMPOSITE | ||||
| self.is_output = is_output | |||||
| self.output_excluded = set() | |||||
| if self.pattern == PrimLib.REDUCE: | |||||
| def _gather_reduce_exclude(op): | |||||
| for to in op.output.to_ops: | |||||
| idx = to.inputs.index(op.output) | |||||
| if self.get_relation(to, idx) > PrimLib.ELEMWISE: | |||||
| self.output_excluded.add(to) | |||||
| else: | |||||
| _gather_reduce_exclude(to) | |||||
| _gather_reduce_exclude(init_op) | |||||
| def __str__(self): | def __str__(self): | ||||
| return '<' + '-'.join([op.output.name for op in self.ops]) + '>' | return '<' + '-'.join([op.output.name for op in self.ops]) + '>' | ||||
| @@ -39,18 +51,21 @@ class GraphSplitByPattern: | |||||
| def __repr__(self): | def __repr__(self): | ||||
| return str(self) | return str(self) | ||||
| def get_relation(self, op, i): | |||||
| relation = PrimLib.UNKNOWN | |||||
| _, elem_relation = PrimLib.input_relation(op, i) | |||||
| for r in elem_relation: | |||||
| if r is None: | |||||
| relation = max(relation, PrimLib.BROADCAST) | |||||
| elif r > relation: | |||||
| relation = r | |||||
| return relation | |||||
| def link_input(self, area_map): | def link_input(self, area_map): | ||||
| """Link inputs""" | """Link inputs""" | ||||
| def get_relation(op, i): | |||||
| relation = PrimLib.UNKNOWN | |||||
| _, elem_relation = PrimLib.input_relation(op, i) | |||||
| for r in elem_relation: | |||||
| if r is not None and r > relation: | |||||
| relation = r | |||||
| return relation | |||||
| for i, t in enumerate(self.ops[0].inputs): | for i, t in enumerate(self.ops[0].inputs): | ||||
| if t.op is not None: | if t.op is not None: | ||||
| area, relation = area_map[t.op], get_relation(self.ops[0], i) | |||||
| area, relation = area_map[t.op], self.get_relation(self.ops[0], i) | |||||
| self.in_relations[area] = relation | self.in_relations[area] = relation | ||||
| def link_output(self): | def link_output(self): | ||||
| @@ -79,7 +94,10 @@ class GraphSplitByPattern: | |||||
| r = rels.pop(area) | r = rels.pop(area) | ||||
| _update_relation(rels, self, r) | _update_relation(rels, self, r) | ||||
| self.ops.extend(area.ops) | |||||
| if self.pattern >= area.pattern: | |||||
| self.ops.extend(area.ops) | |||||
| else: | |||||
| self.ops = area.ops + self.ops | |||||
| _update_pattern() | _update_pattern() | ||||
| _fuse_relation(self.in_relations, area.in_relations) | _fuse_relation(self.in_relations, area.in_relations) | ||||
| _fuse_relation(self.out_relations, area.out_relations) | _fuse_relation(self.out_relations, area.out_relations) | ||||
| @@ -89,6 +107,10 @@ class GraphSplitByPattern: | |||||
| _redirect_relation(a.in_relations) | _redirect_relation(a.in_relations) | ||||
| if self.pattern > PrimLib.RESHAPE: | if self.pattern > PrimLib.RESHAPE: | ||||
| self.mode = self.MODE_COMPOSITE | self.mode = self.MODE_COMPOSITE | ||||
| if area.is_output and not self.is_output: | |||||
| self.is_output = True | |||||
| if area.output_excluded: | |||||
| self.output_excluded.update(area.output_excluded) | |||||
| def check_circle(self, to): | def check_circle(self, to): | ||||
| """Check circle. It returns false if circle exists""" | """Check circle. It returns false if circle exists""" | ||||
| @@ -102,15 +124,27 @@ class GraphSplitByPattern: | |||||
| return False | return False | ||||
| return True | return True | ||||
| BORADCAST_FUSE_DEPTH = 3 | |||||
| REDUCE_FUSE_DEPTH = 3 | |||||
| def dom_op(self): | |||||
| return self.ops[0] | |||||
| def reduce_out_exclude(self, area): | |||||
| if self.output_excluded: | |||||
| for op in self.output_excluded: | |||||
| if op in area.ops: | |||||
| return True | |||||
| return False | |||||
| BORADCAST_FUSE_DEPTH = 20 | |||||
| REDUCE_FUSE_DEPTH = 20 | |||||
| def __init__(self, graph): | def __init__(self, graph): | ||||
| self.graph = graph | self.graph = graph | ||||
| self.areas = [] | self.areas = [] | ||||
| area_map = {} | area_map = {} | ||||
| _, outputs = graph.deduce_parameters() | |||||
| for op in graph.ops: | for op in graph.ops: | ||||
| a = self.Area(op) | |||||
| is_output = op.output in outputs | |||||
| a = self.Area(op, is_output) | |||||
| self.areas.append(a) | self.areas.append(a) | ||||
| area_map[op] = a | area_map[op] = a | ||||
| for a in self.areas: | for a in self.areas: | ||||
| @@ -123,12 +157,20 @@ class GraphSplitByPattern: | |||||
| changed = False | changed = False | ||||
| while True: | while True: | ||||
| for dominant in self.areas: | for dominant in self.areas: | ||||
| fuse_areas = selector(dominant) | |||||
| if fuse_areas: | |||||
| for area in fuse_areas: | |||||
| changed = True | |||||
| dominant.fuse(area) | |||||
| self.areas.remove(area) | |||||
| result = selector(dominant) | |||||
| if result is not None and result[0]: | |||||
| fuse_areas, is_forward = result | |||||
| if is_forward: | |||||
| for area in fuse_areas: | |||||
| dominant.fuse(area) | |||||
| self.areas.remove(area) | |||||
| else: | |||||
| forward_area = dominant | |||||
| for area in fuse_areas: | |||||
| area.fuse(forward_area) | |||||
| self.areas.remove(forward_area) | |||||
| forward_area = area | |||||
| changed = True | |||||
| break | break | ||||
| else: | else: | ||||
| return changed | return changed | ||||
| @@ -148,43 +190,69 @@ class GraphSplitByPattern: | |||||
| def split(self): | def split(self): | ||||
| """Split graph by pattern""" | """Split graph by pattern""" | ||||
| def _reshape(dom): | |||||
| if dom.pattern != PrimLib.RESHAPE: | |||||
| return None | |||||
| min_area, forward_fuse = None, False | |||||
| for a, _ in dom.out_relations.items(): | |||||
| if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ | |||||
| (min_area is None or a.pattern < min_area.pattern): | |||||
| min_area = a | |||||
| for a, _ in dom.in_relations.items(): | |||||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ | |||||
| len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ | |||||
| (min_area is None or a.pattern < min_area.pattern): | |||||
| min_area, forward_fuse = a, True | |||||
| return ([min_area], forward_fuse) if min_area else None | |||||
| def _elemwise_depth(dom): | def _elemwise_depth(dom): | ||||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: | if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: | ||||
| return None | return None | ||||
| a, r = list(dom.in_relations.items())[0] | a, r = list(dom.in_relations.items())[0] | ||||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE: | |||||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \ | |||||
| a.dom_op().output.shape != dom.dom_op().output.shape: | |||||
| return None | return None | ||||
| return [a] | |||||
| return [a], True | |||||
| def _elemwise_width(dom): | def _elemwise_width(dom): | ||||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): | if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): | ||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.in_relations.items(): | for a, r in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom): | |||||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ | |||||
| a.dom_op().output.shape == dom.dom_op().output.shape: | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused | |||||
| return fused, True | |||||
| def _broadcast_pat_exclude(dom, a, r): | |||||
| if use_poly_reduce and a.pattern == PrimLib.REDUCE: | |||||
| return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE | |||||
| return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST | |||||
| def _broadcast_depth(dom): | def _broadcast_depth(dom): | ||||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: | |||||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ | |||||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||||
| return None | return None | ||||
| a, r = list(dom.in_relations.items())[0] | |||||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \ | |||||
| r != PrimLib.BROADCAST or len(a.ops) > self.BORADCAST_FUSE_DEPTH: | |||||
| a, r = list(dom.out_relations.items())[0] | |||||
| if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: | |||||
| return None | return None | ||||
| return [a] | |||||
| return [a], False | |||||
| def _broadcast_width(dom): | def _broadcast_width(dom): | ||||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): | |||||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ | |||||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.in_relations.items(): | |||||
| if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \ | |||||
| a.check_circle(dom) and len(a.ops) <= self.BORADCAST_FUSE_DEPTH: | |||||
| fused.append(a) | |||||
| return fused | |||||
| for a, r in dom.out_relations.items(): | |||||
| if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ | |||||
| (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): | |||||
| return None | |||||
| fused.append(a) | |||||
| return fused, False | |||||
| def _check_reduce_exclude(dom): | def _check_reduce_exclude(dom): | ||||
| if use_poly_reduce: | |||||
| return False | |||||
| # exclude large all-reduce | # exclude large all-reduce | ||||
| if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \ | if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \ | ||||
| dom.ops[0].inputs[0].get_size() > 10000: | dom.ops[0].inputs[0].get_size() > 10000: | ||||
| @@ -198,16 +266,22 @@ class GraphSplitByPattern: | |||||
| return True | return True | ||||
| return False | return False | ||||
| def _reduce_pat_exclude(dom, a, r): | |||||
| if len(a.ops) > self.REDUCE_FUSE_DEPTH: | |||||
| return True | |||||
| if use_poly_reduce: | |||||
| return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST | |||||
| return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE | |||||
| def _reduce_depth(dom): | def _reduce_depth(dom): | ||||
| if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: | if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: | ||||
| return None | return None | ||||
| if _check_reduce_exclude(dom): | if _check_reduce_exclude(dom): | ||||
| return None | return None | ||||
| a, r = list(dom.in_relations.items())[0] | a, r = list(dom.in_relations.items())[0] | ||||
| if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \ | |||||
| r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH: | |||||
| if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: | |||||
| return None | return None | ||||
| return [a] | |||||
| return [a], True | |||||
| def _reduce_width(dom): | def _reduce_width(dom): | ||||
| if dom.pattern != PrimLib.REDUCE: | if dom.pattern != PrimLib.REDUCE: | ||||
| @@ -216,18 +290,51 @@ class GraphSplitByPattern: | |||||
| return None | return None | ||||
| fused = [] | fused = [] | ||||
| for a, r in dom.in_relations.items(): | for a, r in dom.in_relations.items(): | ||||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \ | |||||
| a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH: | |||||
| if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused | |||||
| return fused, True | |||||
| def _tensor_size(tensor): | |||||
| size = 1 | |||||
| for i in tensor.shape: | |||||
| size *= i | |||||
| return size | |||||
| def _reduce_output(dom): | |||||
| if dom.pattern != PrimLib.REDUCE: | |||||
| return None | |||||
| is_all_reduce = _tensor_size(dom.ops[0].output) == 1 | |||||
| # excluded large size all reduce | |||||
| if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: | |||||
| return None | |||||
| fused = [] | |||||
| for a, r in dom.out_relations.items(): | |||||
| if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ | |||||
| dom.check_circle(a) and not dom.reduce_out_exclude(a): | |||||
| fused.append(a) | |||||
| return fused, False | |||||
| def _transpose(dom): | |||||
| if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose": | |||||
| return None | |||||
| fused = [] | |||||
| for a, _ in dom.in_relations.items(): | |||||
| if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom): | |||||
| fused.append(a) | |||||
| return fused, True | |||||
| changed = True | changed = True | ||||
| while changed: | while changed: | ||||
| changed = self.fuse(_elemwise_depth) | |||||
| changed = self.fuse(_reshape) | |||||
| changed = self.fuse(_elemwise_depth) or changed | |||||
| changed = self.fuse(_elemwise_width) or changed | changed = self.fuse(_elemwise_width) or changed | ||||
| changed = self.fuse(_broadcast_depth) or changed | |||||
| changed = self.fuse(_broadcast_width) or changed | |||||
| changed = self.fuse(_reduce_depth) or changed | changed = self.fuse(_reduce_depth) or changed | ||||
| changed = self.fuse(_reduce_width) or changed | changed = self.fuse(_reduce_width) or changed | ||||
| changed = self.fuse(_broadcast_depth) or changed | |||||
| changed = self.fuse(_broadcast_width) or changed | |||||
| if use_poly_reduce: | |||||
| changed = self.fuse(_reduce_output) or changed | |||||
| self.fuse(_transpose) | |||||
| subgraphs, graphmodes = self.to_subgraphs() | subgraphs, graphmodes = self.to_subgraphs() | ||||
| return subgraphs, graphmodes | return subgraphs, graphmodes | ||||