From 32e19e83da2ba03c547df39863aca111aed503e4 Mon Sep 17 00:00:00 2001 From: Gaoxiong Date: Thu, 31 Dec 2020 20:25:45 +0800 Subject: [PATCH] update graph kernel split model for Ascend --- .../graph_kernel/model/graph_split.py | 218 ++++++++++++++---- mindspore/_extends/graph_kernel/splitter.py | 3 +- 2 files changed, 170 insertions(+), 51 deletions(-) diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 34ea68babf..e535083b39 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -16,9 +16,6 @@ from functools import reduce from .model import PrimLib, Graph, Tensor -use_poly_reduce = True - - class GraphSplitByPattern: """Graph splitter""" class Area: @@ -32,7 +29,6 @@ class GraphSplitByPattern: self.in_relations = dict() # {area1: relation1, area2: relation2, ...} self.out_relations = dict() # {area1: relation1, area2: relation2, ...} self.mode = None - self.set_default_mode() self.is_output = is_output self.output_excluded = set() if self.pattern == PrimLib.REDUCE: @@ -51,17 +47,6 @@ class GraphSplitByPattern: def __repr__(self): return str(self) - def set_default_mode(self): - def _get_default_mode(op): - if op.prim == "AddN": - return self.MODE_COMPOSITE - pattern = PrimLib.iter_type(op) - if pattern == PrimLib.TRANSFORM or pattern == PrimLib.BROADCAST or \ - (use_poly_reduce and pattern == PrimLib.REDUCE): - return self.MODE_COMPOSITE - return self.MODE_BASIC - self.mode = _get_default_mode(self.ops[0]) - def get_relation(self, op, i): relation = PrimLib.UNKNOWN _, elem_relation = PrimLib.input_relation(op, i) @@ -145,9 +130,6 @@ class GraphSplitByPattern: return True return False - BORADCAST_FUSE_DEPTH = 20 - REDUCE_FUSE_DEPTH = 20 - def __init__(self, graph): self.graph = graph self.areas = [] @@ -156,6 +138,7 @@ class GraphSplitByPattern: for op in graph.ops: is_output = op.output in outputs a = self.Area(op, is_output) + self.set_default_mode(a) self.areas.append(a) area_map[op] = a for a in self.areas: @@ -163,6 +146,9 @@ class GraphSplitByPattern: for a in self.areas: a.link_output() + def set_default_mode(self, area): + area.mode = self.get_default_mode(area.ops[0]) + def fuse(self, selector): """Fuse areas""" changed = False @@ -200,6 +186,51 @@ class GraphSplitByPattern: return subgraphs, graphmodes def split(self): + """Split graph by pattern""" + self.do_split() + # The reshape should not be output node + # Note: after this function, the input output relation is not maintained. + self.split_output_reshapes() + subgraphs, graphmodes = self.to_subgraphs() + return subgraphs, graphmodes + + def split_output_reshapes(self): + """Force split the output reshapes into other new """ + new_areas = [] + for area in self.areas: + out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE] + remain_ops = [op for op in area.ops if op not in out_reshape_ops] + if not remain_ops or not out_reshape_ops: + continue + changed = True + while changed: + changed = False + for op in out_reshape_ops: + if any([to_op in remain_ops for to_op in op.output.to_ops]): + out_reshape_ops.remove(op) + remain_ops.append(op) + changed = True + break + if out_reshape_ops: + for op in out_reshape_ops: + new_areas.append(self.Area(op, False)) + area.ops = remain_ops + if len(remain_ops) == 1: + self.set_default_mode(area) + if new_areas: + self.areas += new_areas + +use_poly_reduce = True +class GraphSplitGpu(GraphSplitByPattern): + """Graph splitter""" + BORADCAST_FUSE_DEPTH = 20 + REDUCE_FUSE_DEPTH = 20 + + def get_default_mode(self, op): + pattern = PrimLib.iter_type(op) + return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE + + def do_split(self): """Split graph by pattern""" def _reshape(dom): if dom.pattern != PrimLib.RESHAPE: @@ -367,40 +398,127 @@ class GraphSplitByPattern: changed = self.fuse(_reduce_output) or changed self.fuse(_transpose) - # The reshape should not be output node - # Note: after this function, the input output relation is not maintained. - self.split_output_reshapes() +class GraphSplitAscend(GraphSplitByPattern): + """Graph splitter""" + BORADCAST_FUSE_DEPTH = 6 + REDUCE_FUSE_DEPTH = 10 - subgraphs, graphmodes = self.to_subgraphs() - return subgraphs, graphmodes + def get_default_mode(self, op): + if op.prim in ("Tile", "BroadcastTo"): + return self.Area.MODE_COMPOSITE + return self.Area.MODE_BASIC - def split_output_reshapes(self): - """Force split the output reshapes into other new """ - new_areas = [] - for area in self.areas: - out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE] - remain_ops = [op for op in area.ops if op not in out_reshape_ops] - if not remain_ops or not out_reshape_ops: - continue - changed = True - while changed: - changed = False - for op in out_reshape_ops: - if any([to_op in remain_ops for to_op in op.output.to_ops]): - out_reshape_ops.remove(op) - remain_ops.append(op) - changed = True - break - if out_reshape_ops: - for op in out_reshape_ops: - new_areas.append(self.Area(op, False)) - area.ops = remain_ops - if len(remain_ops) == 1: - area.set_default_mode() - if new_areas: - self.areas += new_areas + def do_split(self): + """Split graph by pattern""" + def _tensor_size(tensor): + size = 1 + for i in tensor.shape: + size *= i + return size + + def _likely_multicore(dom): + op = dom.dom_op() + iter_size = _tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0]) + return iter_size > 1024 + def _reshape(dom): + if dom.pattern != PrimLib.RESHAPE: + return None + min_area, forward_fuse = None, False + for a, _ in dom.out_relations.items(): + if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \ + (min_area is None or a.pattern < min_area.pattern): + min_area = a + for a, _ in dom.in_relations.items(): + if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \ + len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \ + (min_area is None or a.pattern < min_area.pattern): + min_area, forward_fuse = a, True + return ([min_area], forward_fuse) if min_area else None + + def _elemwise_depth(dom): + if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1: + return None + a, r = list(dom.in_relations.items())[0] + if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \ + a.dom_op().output.shape != dom.dom_op().output.shape: + return None + return [a], True + + def _elemwise_width(dom): + if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): + return None + fused = [] + for a, r in dom.in_relations.items(): + if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \ + a.dom_op().output.shape == dom.dom_op().output.shape: + fused.append(a) + return fused, True + + def _broadcast_pat_exclude(dom, a, r): + if _likely_multicore(a) and (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH): + return True + return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST + + def _broadcast_depth(dom): + if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1: + return None + a, r = list(dom.out_relations.items())[0] + if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: + return None + return [a], False + + def _broadcast_width(dom): + if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST): + return None + fused = [] + for a, r in dom.out_relations.items(): + if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \ + (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape): + return None + fused.append(a) + return fused, False + + def _reduce_pat_exclude(dom, a, r): + if len(a.ops) > self.REDUCE_FUSE_DEPTH: + return True + if r == PrimLib.BROADCAST and _likely_multicore(dom) and \ + (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH): + return True + return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE + + def _reduce_depth(dom): + if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1: + return None + a, r = list(dom.in_relations.items())[0] + if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: + return None + return [a], True + + def _reduce_width(dom): + if dom.pattern != PrimLib.REDUCE: + return None + fused = [] + for a, r in dom.in_relations.items(): + if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): + fused.append(a) + return fused, True + + changed = True + while changed: + changed = self.fuse(_reshape) + changed = self.fuse(_elemwise_depth) or changed + changed = self.fuse(_elemwise_width) or changed + changed = self.fuse(_reduce_depth) or changed + changed = self.fuse(_reduce_width) or changed + changed = self.fuse(_broadcast_depth) or changed + changed = self.fuse(_broadcast_width) or changed -def split(graph): +def split(graph, target): """Split graph""" - return GraphSplitByPattern(graph).split() + result = None + if target == "cuda": + result = GraphSplitGpu(graph).split() + else: + result = GraphSplitAscend(graph).split() + return result diff --git a/mindspore/_extends/graph_kernel/splitter.py b/mindspore/_extends/graph_kernel/splitter.py index bcb74ac718..0a6535b4b0 100644 --- a/mindspore/_extends/graph_kernel/splitter.py +++ b/mindspore/_extends/graph_kernel/splitter.py @@ -25,8 +25,9 @@ def split_with_json(json_str: str): """Call costmodel to split GraphKernel""" try: graph_desc = json.loads(json_str) + target = graph_desc['process'] comp = model.load_composite(graph_desc) - graph_split, graph_mode = model.split(comp.graph) + graph_split, graph_mode = model.split(comp.graph, target) is_multi_graph = len(graph_split) > 1 graph_list = list(map(comp.dump, graph_split)) result = {"multi_graph": is_multi_graph,