Browse Source

update graph kernel split model for Ascend

tags/v1.2.0-rc1
Gaoxiong 5 years ago
parent
commit
32e19e83da
2 changed files with 170 additions and 51 deletions
  1. +168
    -50
      mindspore/_extends/graph_kernel/model/graph_split.py
  2. +2
    -1
      mindspore/_extends/graph_kernel/splitter.py

+ 168
- 50
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -16,9 +16,6 @@
from functools import reduce
from .model import PrimLib, Graph, Tensor

use_poly_reduce = True


class GraphSplitByPattern:
"""Graph splitter"""
class Area:
@@ -32,7 +29,6 @@ class GraphSplitByPattern:
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
self.mode = None
self.set_default_mode()
self.is_output = is_output
self.output_excluded = set()
if self.pattern == PrimLib.REDUCE:
@@ -51,17 +47,6 @@ class GraphSplitByPattern:
def __repr__(self):
    # Delegate to __str__ so repr and str render the same human-readable form.
    return str(self)

def set_default_mode(self):
    """Initialize this area's mode from its (single) leading op.

    AddN, transform, broadcast and (when poly-reduce is enabled) reduce
    ops start in composite mode; everything else starts basic.
    """
    op = self.ops[0]
    if op.prim == "AddN":
        self.mode = self.MODE_COMPOSITE
        return
    pattern = PrimLib.iter_type(op)
    composite = pattern in (PrimLib.TRANSFORM, PrimLib.BROADCAST) or \
        (use_poly_reduce and pattern == PrimLib.REDUCE)
    self.mode = self.MODE_COMPOSITE if composite else self.MODE_BASIC

def get_relation(self, op, i):
relation = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
@@ -145,9 +130,6 @@ class GraphSplitByPattern:
return True
return False

BORADCAST_FUSE_DEPTH = 20
REDUCE_FUSE_DEPTH = 20

def __init__(self, graph):
self.graph = graph
self.areas = []
@@ -156,6 +138,7 @@ class GraphSplitByPattern:
for op in graph.ops:
is_output = op.output in outputs
a = self.Area(op, is_output)
self.set_default_mode(a)
self.areas.append(a)
area_map[op] = a
for a in self.areas:
@@ -163,6 +146,9 @@ class GraphSplitByPattern:
for a in self.areas:
a.link_output()

def set_default_mode(self, area):
    """Assign `area` its default fusion mode, derived from its leading op."""
    leading_op = area.ops[0]
    area.mode = self.get_default_mode(leading_op)

def fuse(self, selector):
"""Fuse areas"""
changed = False
@@ -200,6 +186,51 @@ class GraphSplitByPattern:
return subgraphs, graphmodes

def split(self):
    """Run pattern-based fusion and return (subgraphs, graphmodes).

    After fusion, output reshapes are forced into their own areas; note
    that the area input/output relations are no longer maintained once
    split_output_reshapes has run.
    """
    self.do_split()
    # A reshape must not end up as an area's output node.
    self.split_output_reshapes()
    return self.to_subgraphs()

def split_output_reshapes(self):
    """Force-split reshape ops that would be area outputs into new areas.

    A reshape op must not be the output node of a fused area, so any
    reshape whose result is only consumed outside its area is moved into
    a fresh single-op area.
    Note: area input/output relations are NOT maintained after this.
    """
    new_areas = []
    for area in self.areas:
        # Partition ops in one pass (the original rebuilt remain_ops with
        # an O(n^2) membership scan), preserving relative order.
        out_reshape_ops, remain_ops = [], []
        for op in area.ops:
            if PrimLib.iter_type(op) == PrimLib.RESHAPE:
                out_reshape_ops.append(op)
            else:
                remain_ops.append(op)
        if not remain_ops or not out_reshape_ops:
            continue
        # A reshape feeding an op that stays in this area must stay too;
        # iterate to a fixed point since each pull-back can enable more.
        changed = True
        while changed:
            changed = False
            for op in out_reshape_ops:
                if any(to_op in remain_ops for to_op in op.output.to_ops):
                    out_reshape_ops.remove(op)
                    remain_ops.append(op)
                    changed = True
                    break
        if out_reshape_ops:
            # Each remaining output reshape becomes its own area.
            for op in out_reshape_ops:
                new_areas.append(self.Area(op, False))
            area.ops = remain_ops
            if len(remain_ops) == 1:
                # A now-single-op area may deserve a different default mode.
                self.set_default_mode(area)
    if new_areas:
        self.areas += new_areas

use_poly_reduce = True
class GraphSplitGpu(GraphSplitByPattern):
"""Graph splitter"""
BORADCAST_FUSE_DEPTH = 20
REDUCE_FUSE_DEPTH = 20

def get_default_mode(self, op):
    """On GPU, reshape ops start in basic mode; all others composite."""
    if PrimLib.iter_type(op) == PrimLib.RESHAPE:
        return self.Area.MODE_BASIC
    return self.Area.MODE_COMPOSITE

def do_split(self):
"""Split graph by pattern"""
def _reshape(dom):
if dom.pattern != PrimLib.RESHAPE:
@@ -367,40 +398,127 @@ class GraphSplitByPattern:
changed = self.fuse(_reduce_output) or changed
self.fuse(_transpose)

# The reshape should not be output node
# Note: after this function, the input output relation is not maintained.
self.split_output_reshapes()
class GraphSplitAscend(GraphSplitByPattern):
"""Graph splitter"""
BORADCAST_FUSE_DEPTH = 6
REDUCE_FUSE_DEPTH = 10

subgraphs, graphmodes = self.to_subgraphs()
return subgraphs, graphmodes
def get_default_mode(self, op):
    """On Ascend, only Tile/BroadcastTo start composite; others basic."""
    is_broadcast_prim = op.prim in ("Tile", "BroadcastTo")
    return self.Area.MODE_COMPOSITE if is_broadcast_prim else self.Area.MODE_BASIC

def split_output_reshapes(self):
    """Force-split reshape ops that would be area outputs into new
    single-op areas (a reshape should not be an area's output node)."""
    new_areas = []
    for area in self.areas:
        # Candidate output reshapes vs. everything else in the area.
        out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
        remain_ops = [op for op in area.ops if op not in out_reshape_ops]
        if not remain_ops or not out_reshape_ops:
            continue
        # Pull back any reshape whose result is consumed inside the area;
        # repeat until stable, since each pull-back can enable further ones.
        changed = True
        while changed:
            changed = False
            for op in out_reshape_ops:
                if any([to_op in remain_ops for to_op in op.output.to_ops]):
                    out_reshape_ops.remove(op)
                    remain_ops.append(op)
                    changed = True
                    break
        if out_reshape_ops:
            # Each remaining output reshape becomes its own new area.
            for op in out_reshape_ops:
                new_areas.append(self.Area(op, False))
            area.ops = remain_ops
            if len(remain_ops) == 1:
                # Recompute the mode now that the area is a single op.
                area.set_default_mode()
    if new_areas:
        self.areas += new_areas
def do_split(self):
    """Split graph by pattern (Ascend backend).

    Each nested `_xxx` helper is a fusion selector: given a dominant
    area it returns (areas_to_fuse, fuse_backward_flag) or None.
    NOTE(review): the exact consumption of that tuple happens in
    self.fuse, which is not fully visible here — confirm against it.
    """
    def _tensor_size(tensor):
        # Product of all dimensions of tensor.shape.
        size = 1
        for i in tensor.shape:
            size *= i
        return size

    def _likely_multicore(dom):
        # Heuristic: an iteration space over 1024 elements is treated as
        # large enough to be spread across multiple cores.  For reduce
        # ops the pre-reduction input size is what matters.
        op = dom.dom_op()
        iter_size = _tensor_size(op.output if not PrimLib.is_reduce(op) else op.inputs[0])
        return iter_size > 1024

    def _reshape(dom):
        # Fuse a reshape area with its smallest-pattern legal neighbour,
        # preferring a consumer; a producer wins only with a strictly
        # smaller pattern.
        if dom.pattern != PrimLib.RESHAPE:
            return None
        min_area, forward_fuse = None, False
        for a, _ in dom.out_relations.items():
            if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \
                    (min_area is None or a.pattern < min_area.pattern):
                min_area = a
        for a, _ in dom.in_relations.items():
            # Backward fuse only when the reshape's input has exactly one
            # consumer and the producer area is not a graph output.
            if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \
                    len(dom.ops[0].inputs[0].to_ops) == 1 and not a.is_output and \
                    (min_area is None or a.pattern < min_area.pattern):
                min_area, forward_fuse = a, True
        return ([min_area], forward_fuse) if min_area else None

    def _elemwise_depth(dom):
        # Chain (depth) fusion: single elementwise producer with a single
        # consumer and an identical output shape.
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.in_relations) != 1:
            return None
        a, r = list(dom.in_relations.items())[0]
        if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \
                a.dom_op().output.shape != dom.dom_op().output.shape:
            return None
        return [a], True

    def _elemwise_width(dom):
        # Width fusion: every elementwise producer with a matching output
        # shape and no circular dependency.
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            return None
        fused = []
        for a, r in dom.in_relations.items():
            if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \
                    a.dom_op().output.shape == dom.dom_op().output.shape:
                fused.append(a)
        return fused, True

    def _broadcast_pat_exclude(dom, a, r):
        # Reject forward broadcast fusion into a likely-multicore consumer
        # when dom is an output or already deep, or on pattern mismatch.
        # NOTE(review): BORADCAST_FUSE_DEPTH is misspelled ("BORADCAST")
        # across this file; renaming would have to touch every subclass.
        if _likely_multicore(a) and (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH):
            return True
        return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST

    def _broadcast_depth(dom):
        # Forward depth fusion along a unique producer->consumer edge.
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1:
            return None
        a, r = list(dom.out_relations.items())[0]
        if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
            return None
        return [a], False

    def _broadcast_width(dom):
        # Forward width fusion: all consumers must pass the exclusion test
        # and share one output shape; otherwise fuse none of them.
        if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            return None
        fused = []
        for a, r in dom.out_relations.items():
            if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \
                    (fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
                return None
            fused.append(a)
        return fused, False

    def _reduce_pat_exclude(dom, a, r):
        # Keep reduce fusion shallow (REDUCE_FUSE_DEPTH) and avoid fusing
        # broadcast producers into likely-multicore reductions.
        if len(a.ops) > self.REDUCE_FUSE_DEPTH:
            return True
        if r == PrimLib.BROADCAST and _likely_multicore(dom) and \
                (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH):
            return True
        return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE

    def _reduce_depth(dom):
        # Backward depth fusion into a reduce area along a unique edge.
        if dom.pattern != PrimLib.REDUCE or len(dom.in_relations) != 1:
            return None
        a, r = list(dom.in_relations.items())[0]
        if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
            return None
        return [a], True

    def _reduce_width(dom):
        # Backward width fusion into a reduce area.
        if dom.pattern != PrimLib.REDUCE:
            return None
        fused = []
        for a, r in dom.in_relations.items():
            if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
                fused.append(a)
        return fused, True

    # Fixed-point loop: apply every selector once per pass until a full
    # pass makes no change.  The first assignment intentionally overwrites
    # `changed` so each pass starts accumulating afresh.
    changed = True
    while changed:
        changed = self.fuse(_reshape)
        changed = self.fuse(_elemwise_depth) or changed
        changed = self.fuse(_elemwise_width) or changed
        changed = self.fuse(_reduce_depth) or changed
        changed = self.fuse(_reduce_width) or changed
        changed = self.fuse(_broadcast_depth) or changed
        changed = self.fuse(_broadcast_width) or changed

def split(graph, target):
    """Split `graph` into subgraphs by pattern for the given backend.

    Args:
        graph: the composite Graph to split.
        target: device target string; "cuda" selects the GPU splitter,
            anything else falls back to the Ascend splitter.

    Returns:
        The (subgraphs, graphmodes) tuple from the backend splitter.
    """
    if target == "cuda":
        return GraphSplitGpu(graph).split()
    return GraphSplitAscend(graph).split()

+ 2
- 1
mindspore/_extends/graph_kernel/splitter.py View File

@@ -25,8 +25,9 @@ def split_with_json(json_str: str):
"""Call costmodel to split GraphKernel"""
try:
graph_desc = json.loads(json_str)
target = graph_desc['process']
comp = model.load_composite(graph_desc)
graph_split, graph_mode = model.split(comp.graph)
graph_split, graph_mode = model.split(comp.graph, target)
is_multi_graph = len(graph_split) > 1
graph_list = list(map(comp.dump, graph_split))
result = {"multi_graph": is_multi_graph,


Loading…
Cancel
Save