Browse Source

update graph kernel split model

tags/v1.1.0
Gaoxiong 5 years ago
parent
commit
e4c3d3e0e9
1 changed files with 151 additions and 44 deletions
  1. +151
    -44
      mindspore/_extends/graph_kernel/model/graph_split.py

+ 151
- 44
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -16,6 +16,7 @@

from .model import PrimLib, Graph, Tensor

use_poly_reduce = False

class GraphSplitByPattern:
"""Graph splitter"""
@@ -24,14 +25,25 @@ class GraphSplitByPattern:
MODE_BASIC = 1
MODE_COMPOSITE = 2

def __init__(self, init_op):
def __init__(self, init_op, is_output):
self.pattern = PrimLib.iter_type(init_op)
self.ops = [init_op]
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
self.mode = self.MODE_BASIC
if self.pattern == PrimLib.TRANSFORM:
if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE):
self.mode = self.MODE_COMPOSITE
self.is_output = is_output
self.output_excluded = set()
if self.pattern == PrimLib.REDUCE:
def _gather_reduce_exclude(op):
for to in op.output.to_ops:
idx = to.inputs.index(op.output)
if self.get_relation(to, idx) > PrimLib.ELEMWISE:
self.output_excluded.add(to)
else:
_gather_reduce_exclude(to)
_gather_reduce_exclude(init_op)

def __str__(self):
return '<' + '-'.join([op.output.name for op in self.ops]) + '>'
@@ -39,18 +51,21 @@ class GraphSplitByPattern:
def __repr__(self):
return str(self)

def get_relation(self, op, i):
relation = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
for r in elem_relation:
if r is None:
relation = max(relation, PrimLib.BROADCAST)
elif r > relation:
relation = r
return relation

def link_input(self, area_map):
"""Link inputs"""
def get_relation(op, i):
relation = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
for r in elem_relation:
if r is not None and r > relation:
relation = r
return relation
for i, t in enumerate(self.ops[0].inputs):
if t.op is not None:
area, relation = area_map[t.op], get_relation(self.ops[0], i)
area, relation = area_map[t.op], self.get_relation(self.ops[0], i)
self.in_relations[area] = relation

def link_output(self):
@@ -79,7 +94,10 @@ class GraphSplitByPattern:
r = rels.pop(area)
_update_relation(rels, self, r)

self.ops.extend(area.ops)
if self.pattern >= area.pattern:
self.ops.extend(area.ops)
else:
self.ops = area.ops + self.ops
_update_pattern()
_fuse_relation(self.in_relations, area.in_relations)
_fuse_relation(self.out_relations, area.out_relations)
@@ -89,6 +107,10 @@ class GraphSplitByPattern:
_redirect_relation(a.in_relations)
if self.pattern > PrimLib.RESHAPE:
self.mode = self.MODE_COMPOSITE
if area.is_output and not self.is_output:
self.is_output = True
if area.output_excluded:
self.output_excluded.update(area.output_excluded)

def check_circle(self, to):
"""Check circle. It returns false if circle exists"""
@@ -102,15 +124,27 @@ class GraphSplitByPattern:
return False
return True

BORADCAST_FUSE_DEPTH = 3
REDUCE_FUSE_DEPTH = 3
def dom_op(self):
return self.ops[0]

def reduce_out_exclude(self, area):
if self.output_excluded:
for op in self.output_excluded:
if op in area.ops:
return True
return False

BORADCAST_FUSE_DEPTH = 20
REDUCE_FUSE_DEPTH = 20

def __init__(self, graph):
self.graph = graph
self.areas = []
area_map = {}
_, outputs = graph.deduce_parameters()
for op in graph.ops:
a = self.Area(op)
is_output = op.output in outputs
a = self.Area(op, is_output)
self.areas.append(a)
area_map[op] = a
for a in self.areas:
@@ -123,12 +157,20 @@ class GraphSplitByPattern:
changed = False
while True:
for dominant in self.areas:
fuse_areas = selector(dominant)
if fuse_areas:
for area in fuse_areas:
changed = True
dominant.fuse(area)
self.areas.remove(area)
result = selector(dominant)
if result is not None and result[0]:
fuse_areas, is_forward = result
if is_forward:
for area in fuse_areas:
dominant.fuse(area)
self.areas.remove(area)
else:
forward_area = dominant
for area in fuse_areas:
area.fuse(forward_area)
self.areas.remove(forward_area)
forward_area = area
changed = True
break
else:
return changed
@@ -148,43 +190,69 @@ class GraphSplitByPattern:

def split(self):
"""Split graph by pattern"""
def _reshape(dom):
if dom.pattern != PrimLib.RESHAPE:
return None
min_area, forward_fuse = None, False
for a, _ in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \
(min_area is None or a.pattern < min_area.pattern):
min_area = a
for a, _ in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \
len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
(min_area is None or a.pattern < min_area.pattern):
min_area, forward_fuse = a, True
return ([min_area], forward_fuse) if min_area else None

def _elemwise_depth(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
return None
a, r = list(dom.in_relations.items())[0]
if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 and r != PrimLib.ELEMWISE:
if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \
a.dom_op().output.shape != dom.dom_op().output.shape:
return None
return [a]
return [a], True

def _elemwise_width(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
return None
fused = []
for a, r in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom):
if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \
a.dom_op().output.shape == dom.dom_op().output.shape:
fused.append(a)
return fused
return fused, True

def _broadcast_pat_exclude(dom, a, r):
if use_poly_reduce and a.pattern == PrimLib.REDUCE:
return dom.pattern > PrimLib.ELEMWISE or r > PrimLib.ELEMWISE
return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST

def _broadcast_depth(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
return None
a, r = list(dom.in_relations.items())[0]
if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
r != PrimLib.BROADCAST or len(a.ops) > self.BORADCAST_FUSE_DEPTH:
a, r = list(dom.out_relations.items())[0]
if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
return None
return [a]
return [a], False

def _broadcast_width(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
return None
fused = []
for a, r in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and r == PrimLib.BROADCAST and \
a.check_circle(dom) and len(a.ops) <= self.BORADCAST_FUSE_DEPTH:
fused.append(a)
return fused
for a, r in dom.out_relations.items():
if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \
(fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
return None
fused.append(a)
return fused, False

def _check_reduce_exclude(dom):
if use_poly_reduce:
return False
# exclude large all-reduce
if len(dom.ops[0].inputs[0].shape) == len(dom.ops[0].attrs["reduce_axis"]) and \
dom.ops[0].inputs[0].get_size() > 10000:
@@ -198,16 +266,22 @@ class GraphSplitByPattern:
return True
return False

def _reduce_pat_exclude(dom, a, r):
if len(a.ops) > self.REDUCE_FUSE_DEPTH:
return True
if use_poly_reduce:
return a.pattern > PrimLib.ELEMWISE or r > PrimLib.REDUCE or r == PrimLib.BROADCAST
return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE

def _reduce_depth(dom):
if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
return None
if _check_reduce_exclude(dom):
return None
a, r = list(dom.in_relations.items())[0]
if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or \
r > PrimLib.REDUCE or len(a.ops) > self.REDUCE_FUSE_DEPTH:
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
return None
return [a]
return [a], True

def _reduce_width(dom):
if dom.pattern != PrimLib.REDUCE:
@@ -216,18 +290,51 @@ class GraphSplitByPattern:
return None
fused = []
for a, r in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.REDUCE and \
a.check_circle(dom) and len(a.ops) <= self.REDUCE_FUSE_DEPTH:
if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
fused.append(a)
return fused
return fused, True

def _tensor_size(tensor):
size = 1
for i in tensor.shape:
size *= i
return size

def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
is_all_reduce = _tensor_size(dom.ops[0].output) == 1
# excluded large size all reduce
if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
return None
fused = []
for a, r in dom.out_relations.items():
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
dom.check_circle(a) and not dom.reduce_out_exclude(a):
fused.append(a)
return fused, False

def _transpose(dom):
if len(dom.ops) != 1 or dom.ops[0].prim != "Transpose":
return None
fused = []
for a, _ in dom.in_relations.items():
if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom):
fused.append(a)
return fused, True

changed = True
while changed:
changed = self.fuse(_elemwise_depth)
changed = self.fuse(_reshape)
changed = self.fuse(_elemwise_depth) or changed
changed = self.fuse(_elemwise_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_reduce_depth) or changed
changed = self.fuse(_reduce_width) or changed
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
self.fuse(_transpose)
subgraphs, graphmodes = self.to_subgraphs()
return subgraphs, graphmodes



Loading…
Cancel
Save