diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index d6cee33c0c..48408859a9 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -16,6 +16,7 @@ from .model import PrimLib, Graph, Tensor +use_poly_reduce = False class GraphSplitByPattern: """Graph splitter""" @@ -24,14 +25,25 @@ class GraphSplitByPattern: MODE_BASIC = 1 MODE_COMPOSITE = 2 - def __init__(self, init_op): + def __init__(self, init_op, is_output): self.pattern = PrimLib.iter_type(init_op) self.ops = [init_op] self.in_relations = dict() # {area1: relation1, area2: relation2, ...} self.out_relations = dict() # {area1: relation1, area2: relation2, ...} self.mode = self.MODE_BASIC - if self.pattern == PrimLib.TRANSFORM: + if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE): self.mode = self.MODE_COMPOSITE + self.is_output = is_output + self.output_excluded = set() + if self.pattern == PrimLib.REDUCE: + def _gather_reduce_exclude(op): + for to in op.output.to_ops: + idx = to.inputs.index(op.output) + if self.get_relation(to, idx) > PrimLib.ELEMWISE: + self.output_excluded.add(to) + else: + _gather_reduce_exclude(to) + _gather_reduce_exclude(init_op) def __str__(self): return '<' + '-'.join([op.output.name for op in self.ops]) + '>' @@ -39,18 +51,21 @@ class GraphSplitByPattern: def __repr__(self): return str(self) + def get_relation(self, op, i): + relation = PrimLib.UNKNOWN + _, elem_relation = PrimLib.input_relation(op, i) + for r in elem_relation: + if r is None: + relation = max(relation, PrimLib.BROADCAST) + elif r > relation: + relation = r + return relation + def link_input(self, area_map): """Link inputs""" - def get_relation(op, i): - relation = PrimLib.UNKNOWN - _, elem_relation = PrimLib.input_relation(op, i) - for r in elem_relation: - if r is not None and r > relation: - relation = r - return relation for i, t in enumerate(self.ops[0].inputs): if t.op is not None: - area, relation = area_map[t.op], get_relation(self.ops[0], i) + area, relation = area_map[t.op], self.get_relation(self.ops[0], i) self.in_relations[area] = relation def link_output(self): @@ -79,7 +94,10 @@ class GraphSplitByPattern: r = rels.pop(area) _update_relation(rels, self, r) - self.ops.extend(area.ops) + if self.pattern >= area.pattern: + self.ops.extend(area.ops) + else: + self.ops = area.ops + self.ops _update_pattern() _fuse_relation(self.in_relations, area.in_relations) _fuse_relation(self.out_relations, area.out_relations) @@ -89,6 +107,10 @@ class GraphSplitByPattern: _redirect_relation(a.in_relations) if self.pattern > PrimLib.RESHAPE: self.mode = self.MODE_COMPOSITE + if area.is_output and not self.is_output: + self.is_output = True + if area.output_excluded: + self.output_excluded.update(area.output_excluded) def check_circle(self, to): """Check circle. It returns false if circle exists""" @@ -102,15 +124,27 @@ class GraphSplitByPattern: return False return True - BORADCAST_FUSE_DEPTH = 3 - REDUCE_FUSE_DEPTH = 3 + def dom_op(self): + return self.ops[0] + + def reduce_out_exclude(self, area): + if self.output_excluded: + for op in self.output_excluded: + if op in area.ops: + return True + return False + + BORADCAST_FUSE_DEPTH = 20 + REDUCE_FUSE_DEPTH = 20 def __init__(self, graph): self.graph = graph self.areas = [] area_map = {} + _, outputs = graph.deduce_parameters() for op in graph.ops: - a = self.Area(op) + is_output = op.output in outputs + a = self.Area(op, is_output) self.areas.append(a) area_map[op] = a for a in self.areas: @@ -123,12 +157,20 @@ class GraphSplitByPattern: changed = False while True: for dominant in self.areas: - fuse_areas = selector(dominant) - if fuse_areas: - for area in fuse_areas: - changed = True - dominant.fuse(area) - self.areas.remove(area) + result = selector(dominant) + if result is not None and result[0]: + fuse_areas, is_forward = result + if is_forward: + for area in fuse_areas: + dominant.fuse(area) + self.areas.remove(area) + else: + forward_area = dominant + for area in fuse_areas: + area.fuse(forward_area) + self.areas.remove(forward_area) + forward_area = area + changed = True break else: return changed @@ -148,43 +190,69 @@ class GraphSplitByPattern: def split(self): """Split graph by pattern""" + def _reshape(dom): + if dom.pattern != PrimLib.RESHAPE: + return None + min_area, forward_fuse = None, False + for a, _ in dom.out_relations.items(): + if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ + (min_area is None or a.pattern < min_area.pattern): + min_area = a + for a, _ in dom.in_relations.items(): + if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ + len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ + (min_area is None or a.pattern < min_area.pattern): + min_area, forward_fuse = a, True + return ([min_area], forward_fuse) if min_area else None + def _elemwise_depth(dom): if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: return None a, r = list(dom.in_relations.items())[0] - if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE: + if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \ + a.dom_op().output.shape != dom.dom_op().output.shape: return None - return [a] + return [a], True def _elemwise_width(dom): if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): return None fused = [] for a, r in dom.in_relations.items(): - if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom): + if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ + a.dom_op().output.shape == dom.dom_op().output.shape: fused.append(a) - return fused + return fused, True + + def _broadcast_pat_exclude(dom, a, r): + if use_poly_reduce and a.pattern == PrimLib.REDUCE: + return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE + return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST def _broadcast_depth(dom): - if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: + if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ + dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: return None - a, r = list(dom.in_relations.items())[0] - if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \ - r != PrimLib.BROADCAST or len(a.ops) > self.BORADCAST_FUSE_DEPTH: + a, r = list(dom.out_relations.items())[0] + if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: return None - return [a] + return [a], False def _broadcast_width(dom): - if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): + if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ + dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: return None fused = [] - for a, r in dom.in_relations.items(): - if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \ - a.check_circle(dom) and len(a.ops) <= self.BORADCAST_FUSE_DEPTH: - fused.append(a) - return fused + for a, r in dom.out_relations.items(): + if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ + (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): + return None + fused.append(a) + return fused, False def _check_reduce_exclude(dom): + if use_poly_reduce: + return False # exclude large all-reduce if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \ dom.ops[0].inputs[0].get_size() > 10000: @@ -198,16 +266,22 @@ class GraphSplitByPattern: return True return False + def _reduce_pat_exclude(dom, a, r): + if len(a.ops) > self.REDUCE_FUSE_DEPTH: + return True + if use_poly_reduce: + return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST + return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE + def _reduce_depth(dom): if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: return None if _check_reduce_exclude(dom): return None a, r = list(dom.in_relations.items())[0] - if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \ - r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH: + if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: return None - return [a] + return [a], True def _reduce_width(dom): if dom.pattern != PrimLib.REDUCE: @@ -216,18 +290,51 @@ class GraphSplitByPattern: return None fused = [] for a, r in dom.in_relations.items(): - if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \ - a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH: + if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): fused.append(a) - return fused + return fused, True + + def _tensor_size(tensor): + size = 1 + for i in tensor.shape: + size *= i + return size + + def _reduce_output(dom): + if dom.pattern != PrimLib.REDUCE: + return None + is_all_reduce = _tensor_size(dom.ops[0].output) == 1 + # excluded large size all reduce + if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12: + return None + fused = [] + for a, r in dom.out_relations.items(): + if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \ + dom.check_circle(a) and not dom.reduce_out_exclude(a): + fused.append(a) + return fused, False + + def _transpose(dom): + if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose": + return None + fused = [] + for a, _ in dom.in_relations.items(): + if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom): + fused.append(a) + return fused, True + changed = True while changed: - changed = self.fuse(_elemwise_depth) + changed = self.fuse(_reshape) + changed = self.fuse(_elemwise_depth) or changed changed = self.fuse(_elemwise_width) or changed - changed = self.fuse(_broadcast_depth) or changed - changed = self.fuse(_broadcast_width) or changed changed = self.fuse(_reduce_depth) or changed changed = self.fuse(_reduce_width) or changed + changed = self.fuse(_broadcast_depth) or changed + changed = self.fuse(_broadcast_width) or changed + if use_poly_reduce: + changed = self.fuse(_reduce_output) or changed + self.fuse(_transpose) subgraphs, graphmodes = self.to_subgraphs() return subgraphs, graphmodes