|
|
|
@@ -16,9 +16,6 @@ |
|
|
|
from functools import reduce |
|
|
|
from .model import PrimLib, Graph, Tensor |
|
|
|
|
|
|
|
use_poly_reduce = True |
|
|
|
|
|
|
|
|
|
|
|
class GraphSplitByPattern: |
|
|
|
"""Graph splitter""" |
|
|
|
class Area: |
|
|
|
@@ -32,7 +29,6 @@ class GraphSplitByPattern: |
|
|
|
self.in_relations = dict() # {area1: relation1, area2: relation2, ...} |
|
|
|
self.out_relations = dict() # {area1: relation1, area2: relation2, ...} |
|
|
|
self.mode = None |
|
|
|
self.set_default_mode() |
|
|
|
self.is_output = is_output |
|
|
|
self.output_excluded = set() |
|
|
|
if self.pattern == PrimLib.REDUCE: |
|
|
|
@@ -51,17 +47,6 @@ class GraphSplitByPattern: |
|
|
|
def __repr__(self):
    """Return the same text as ``str(self)`` so both representations agree."""
    return "{!s}".format(self)
|
|
|
|
|
|
|
def set_default_mode(self):
    """Set ``self.mode`` from the first op of this area.

    The mode is ``MODE_COMPOSITE`` when the first op is an ``AddN``, or when
    its iteration pattern is TRANSFORM or BROADCAST, or REDUCE while the
    module-level ``use_poly_reduce`` switch is on. Otherwise it is
    ``MODE_BASIC``.
    """
    lead_op = self.ops[0]
    if lead_op.prim == "AddN":
        self.mode = self.MODE_COMPOSITE
        return
    kind = PrimLib.iter_type(lead_op)
    wants_composite = kind in (PrimLib.TRANSFORM, PrimLib.BROADCAST) or \
        (use_poly_reduce and kind == PrimLib.REDUCE)
    self.mode = self.MODE_COMPOSITE if wants_composite else self.MODE_BASIC
|
|
|
|
|
|
|
def get_relation(self, op, i): |
|
|
|
relation = PrimLib.UNKNOWN |
|
|
|
_, elem_relation = PrimLib.input_relation(op, i) |
|
|
|
@@ -145,9 +130,6 @@ class GraphSplitByPattern: |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
BORADCAST_FUSE_DEPTH = 20 |
|
|
|
REDUCE_FUSE_DEPTH = 20 |
|
|
|
|
|
|
|
def __init__(self, graph): |
|
|
|
self.graph = graph |
|
|
|
self.areas = [] |
|
|
|
@@ -156,6 +138,7 @@ class GraphSplitByPattern: |
|
|
|
for op in graph.ops: |
|
|
|
is_output = op.output in outputs |
|
|
|
a = self.Area(op, is_output) |
|
|
|
self.set_default_mode(a) |
|
|
|
self.areas.append(a) |
|
|
|
area_map[op] = a |
|
|
|
for a in self.areas: |
|
|
|
@@ -163,6 +146,9 @@ class GraphSplitByPattern: |
|
|
|
for a in self.areas: |
|
|
|
a.link_output() |
|
|
|
|
|
|
|
def set_default_mode(self, area):
    """Give ``area`` the default mode that this splitter picks for its first op.

    Args:
        area: area whose ``mode`` attribute is overwritten; ``area.ops`` must
            contain at least one op.
    """
    leading_op = area.ops[0]
    area.mode = self.get_default_mode(leading_op)
|
|
|
|
|
|
|
def fuse(self, selector): |
|
|
|
"""Fuse areas""" |
|
|
|
changed = False |
|
|
|
@@ -200,6 +186,51 @@ class GraphSplitByPattern: |
|
|
|
return subgraphs, graphmodes |
|
|
|
|
|
|
|
def split(self):
    """Run the full splitting pipeline and return its result.

    Returns:
        A ``(subgraphs, graphmodes)`` tuple as produced by ``to_subgraphs()``.
    """
    self.do_split()

    # A reshape op should not be left as an output node, so such ops are
    # forced out of their areas before the subgraphs are built.
    # NOTE: once this step has run, the input/output relations between
    # areas are no longer maintained.
    self.split_output_reshapes()

    pieces, modes = self.to_subgraphs()
    return pieces, modes
|
|
|
|
|
|
|
def split_output_reshapes(self):
    """Force trailing reshape ops out of their areas into new single-op areas.

    Within each area that mixes reshape ops with other ops, a reshape op is
    kept only when its output is consumed by an op that stays in the area
    (directly, or through other kept reshape ops). Every other reshape op is
    moved into a fresh area of its own, appended to ``self.areas``.

    Each new area gets its default mode from this splitter, the same way
    ``__init__`` does for the areas it creates. An area reduced to a single
    op has its default mode recomputed.

    Note: the input/output relations between areas are not updated here.
    """
    new_areas = []
    for area in self.areas:
        out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
        remain_ops = [op for op in area.ops if op not in out_reshape_ops]
        if not remain_ops or not out_reshape_ops:
            continue

        # Fixpoint: keeping one reshape can make another reshape "consumed
        # inside the area", so rescan from the start after every move.
        while True:
            kept = next((op for op in out_reshape_ops
                         if any(to_op in remain_ops for to_op in op.output.to_ops)), None)
            if kept is None:
                break
            out_reshape_ops.remove(kept)
            remain_ops.append(kept)

        if not out_reshape_ops:
            continue

        for op in out_reshape_ops:
            new_area = self.Area(op, False)
            # Without this the new area would skip the splitter-specific
            # default mode that every area built in __init__ receives.
            self.set_default_mode(new_area)
            new_areas.append(new_area)
        area.ops = remain_ops
        if len(remain_ops) == 1:
            self.set_default_mode(area)

    self.areas.extend(new_areas)
|
|
|
|
|
|
|
use_poly_reduce = True |
|
|
|
class GraphSplitGpu(GraphSplitByPattern): |
|
|
|
"""Graph splitter""" |
|
|
|
BORADCAST_FUSE_DEPTH = 20 |
|
|
|
REDUCE_FUSE_DEPTH = 20 |
|
|
|
|
|
|
|
def get_default_mode(self, op):
    """Return the default area mode for ``op``.

    Ops whose iteration pattern is RESHAPE get ``MODE_BASIC``; all other ops
    get ``MODE_COMPOSITE``.
    """
    if PrimLib.iter_type(op) == PrimLib.RESHAPE:
        return self.Area.MODE_BASIC
    return self.Area.MODE_COMPOSITE
|
|
|
|
|
|
|
def do_split(self): |
|
|
|
"""Split graph by pattern""" |
|
|
|
def _reshape(dom): |
|
|
|
if dom.pattern != PrimLib.RESHAPE: |
|
|
|
@@ -367,40 +398,127 @@ class GraphSplitByPattern: |
|
|
|
changed = self.fuse(_reduce_output) or changed |
|
|
|
self.fuse(_transpose) |
|
|
|
|
|
|
|
# The reshape should not be output node |
|
|
|
# Note: after this function, the input output relation is not maintained. |
|
|
|
self.split_output_reshapes() |
|
|
|
class GraphSplitAscend(GraphSplitByPattern): |
|
|
|
"""Graph splitter""" |
|
|
|
BORADCAST_FUSE_DEPTH = 6 |
|
|
|
REDUCE_FUSE_DEPTH = 10 |
|
|
|
|
|
|
|
subgraphs, graphmodes = self.to_subgraphs() |
|
|
|
return subgraphs, graphmodes |
|
|
|
def get_default_mode(self, op):
    """Return the default area mode for ``op``.

    ``Tile`` and ``BroadcastTo`` ops get ``MODE_COMPOSITE``; all other ops
    get ``MODE_BASIC``.
    """
    is_broadcast_prim = op.prim in ("Tile", "BroadcastTo")
    return self.Area.MODE_COMPOSITE if is_broadcast_prim else self.Area.MODE_BASIC
|
|
|
|
|
|
|
def split_output_reshapes(self):
    """Force trailing reshape ops out of their areas into new single-op areas.

    Within each area that mixes reshape ops with other ops, a reshape op is
    kept only when its output is consumed by an op that stays in the area
    (directly, or through other kept reshape ops). Every other reshape op is
    moved into a fresh area of its own, appended to ``self.areas``.

    Default modes are assigned through ``self.set_default_mode`` so that they
    follow this splitter's ``get_default_mode``, matching what ``__init__``
    does for the areas it creates.

    Note: the input/output relations between areas are not updated here.
    """
    new_areas = []
    for area in self.areas:
        out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
        remain_ops = [op for op in area.ops if op not in out_reshape_ops]
        if not remain_ops or not out_reshape_ops:
            continue

        # Fixpoint: keeping one reshape can make another reshape "consumed
        # inside the area", so rescan from the start after every move.
        while True:
            kept = next((op for op in out_reshape_ops
                         if any(to_op in remain_ops for to_op in op.output.to_ops)), None)
            if kept is None:
                break
            out_reshape_ops.remove(kept)
            remain_ops.append(kept)

        if not out_reshape_ops:
            continue

        for op in out_reshape_ops:
            new_area = self.Area(op, False)
            # Use the splitter's mode policy for the new area as well.
            self.set_default_mode(new_area)
            new_areas.append(new_area)
        area.ops = remain_ops
        if len(remain_ops) == 1:
            # Was area.set_default_mode(), which ignored the splitter-specific
            # get_default_mode used everywhere else.
            self.set_default_mode(area)

    self.areas.extend(new_areas)
|
|
|
def do_split(self):
    """Fuse areas by repeatedly applying the pattern selectors defined below.

    Each selector takes a dominant area ``dom`` and returns either ``None``
    (nothing to fuse) or a ``(areas, flag)`` pair. ``flag`` is True when the
    candidate areas were taken from ``dom.in_relations`` and False when they
    were taken from ``dom.out_relations``. The selectors are handed to
    ``self.fuse`` in a fixed order until a full round changes nothing.
    """

    def _tensor_size(tensor):
        """Return the product of all entries of ``tensor.shape``."""
        return reduce(lambda total, extent: total * extent, tensor.shape, 1)

    def _likely_multicore(dom):
        """Tell whether the dominant op of ``dom`` iterates over more than 1024 elements."""
        dominant = dom.dom_op()
        # A reduce op is measured by its first input, any other op by its output.
        measured = dominant.inputs[0] if PrimLib.is_reduce(dominant) else dominant.output
        return _tensor_size(measured) > 1024

    def _reshape(dom):
        """Choose the neighbour with the smallest pattern for a reshape area."""
        if dom.pattern != PrimLib.RESHAPE:
            return None
        best, forward = None, False
        for user in dom.out_relations:
            if user.pattern > PrimLib.BROADCAST or not dom.check_circle(user):
                continue
            if best is None or user.pattern < best.pattern:
                best = user
        for producer in dom.in_relations:
            if producer.pattern > PrimLib.BROADCAST or not producer.check_circle(dom):
                continue
            # Only consider the producer when the reshape is the sole consumer
            # of its first input and the producer is not an output area.
            if len(dom.ops[0].inputs[0].to_ops) != 1 or producer.is_output:
                continue
            if best is None or producer.pattern < best.pattern:
                best, forward = producer, True
        if not best:
            return None
        return [best], forward

    def _elemwise_depth(dom):
        """Select the single same-shape elemwise producer of ``dom``, if any."""
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            return None
        if len(dom.in_relations) != 1:
            return None
        producer, rel = next(iter(dom.in_relations.items()))
        if producer.pattern > PrimLib.BROADCAST or len(producer.out_relations) != 1:
            return None
        if rel != PrimLib.ELEMWISE:
            return None
        if producer.dom_op().output.shape != dom.dom_op().output.shape:
            return None
        return [producer], True

    def _elemwise_width(dom):
        """Select every same-shape elemwise producer of ``dom``."""
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            return None
        picked = [
            producer for producer, rel in dom.in_relations.items()
            if producer.pattern <= PrimLib.BROADCAST and rel == PrimLib.ELEMWISE
            and producer.check_circle(dom)
            and producer.dom_op().output.shape == dom.dom_op().output.shape
        ]
        return picked, True

    def _broadcast_pat_exclude(dom, user, rel):
        """Tell whether ``user`` must not be fused into ``dom`` by the broadcast selectors."""
        depth_exceeded = _likely_multicore(user) and \
            (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH)
        if depth_exceeded:
            return True
        return user.pattern > PrimLib.REDUCE or rel > PrimLib.BROADCAST

    def _broadcast_depth(dom):
        """Select the single consumer of ``dom`` when ``dom`` is its only producer."""
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            return None
        if len(dom.out_relations) != 1:
            return None
        user, rel = next(iter(dom.out_relations.items()))
        if _broadcast_pat_exclude(dom, user, rel):
            return None
        if len(user.in_relations) != 1:
            return None
        return [user], False

    def _broadcast_width(dom):
        """Select all consumers of ``dom``; give up if any of them is unsuitable."""
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            return None
        picked = []
        for user, rel in dom.out_relations.items():
            if _broadcast_pat_exclude(dom, user, rel):
                return None
            if not dom.check_circle(user):
                return None
            # All consumers must share the output shape of the first one.
            if picked and picked[0].dom_op().output.shape != user.dom_op().output.shape:
                return None
            picked.append(user)
        return picked, False

    def _reduce_pat_exclude(dom, producer, rel):
        """Tell whether ``producer`` must not be fused into the reduce area ``dom``."""
        if len(producer.ops) > self.REDUCE_FUSE_DEPTH:
            return True
        depth_exceeded = rel == PrimLib.BROADCAST and _likely_multicore(dom) and \
            (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH)
        if depth_exceeded:
            return True
        return producer.pattern > PrimLib.BROADCAST or rel > PrimLib.REDUCE

    def _reduce_depth(dom):
        """Select the single producer of a reduce area when ``dom`` is its only consumer."""
        if dom.pattern != PrimLib.REDUCE:
            return None
        if len(dom.in_relations) != 1:
            return None
        producer, rel = next(iter(dom.in_relations.items()))
        if _reduce_pat_exclude(dom, producer, rel):
            return None
        if len(producer.out_relations) != 1:
            return None
        return [producer], True

    def _reduce_width(dom):
        """Select every admissible producer of a reduce area."""
        if dom.pattern != PrimLib.REDUCE:
            return None
        picked = [
            producer for producer, rel in dom.in_relations.items()
            if not _reduce_pat_exclude(dom, producer, rel) and producer.check_circle(dom)
        ]
        return picked, True

    selectors = (
        _reshape,
        _elemwise_depth,
        _elemwise_width,
        _reduce_depth,
        _reduce_width,
        _broadcast_depth,
        _broadcast_width,
    )
    changed = True
    while changed:
        changed = False
        # Every selector runs in every round, even after an earlier one
        # already reported a change.
        for selector in selectors:
            changed = self.fuse(selector) or changed
|
|
|
|
|
|
|
def split(graph, target):
    """Split ``graph`` with the splitter that matches ``target``.

    Args:
        graph: graph to split; passed unchanged to the splitter constructor.
        target: backend name. ``"cuda"`` selects ``GraphSplitGpu``; any other
            value selects ``GraphSplitAscend``.

    Returns:
        Whatever the selected splitter's ``split()`` method returns.
    """
    # The former unconditional GraphSplitByPattern(graph).split() return is
    # dropped on purpose: it ignored ``target`` and made the dispatch below
    # unreachable.
    splitter_cls = GraphSplitGpu if target == "cuda" else GraphSplitAscend
    return splitter_cls(graph).split()