|
|
|
@@ -19,6 +19,34 @@ from .model import PrimLib, Graph, Tensor |
|
|
|
|
|
|
|
class GraphSplitByPattern: |
|
|
|
"""Graph splitter""" |
|
|
|
class ReachTable: |
|
|
|
"""Reachable table""" |
|
|
|
def __init__(self, size): |
|
|
|
self.map = [] |
|
|
|
self.alive = set(range(size)) |
|
|
|
for i in range(0, size): |
|
|
|
self.map.append([False for j in range(0, size)]) |
|
|
|
self.map[i][i] = True |
|
|
|
|
|
|
|
def reachable(self, x, y): |
|
|
|
"""reachable from x to y""" |
|
|
|
return self.map[x][y] |
|
|
|
|
|
|
|
def sync(self, x, y): |
|
|
|
"""sync from y to x""" |
|
|
|
for i in self.alive: |
|
|
|
if self.map[y][i] and not self.map[x][i]: |
|
|
|
self.map[x][i] = True |
|
|
|
|
|
|
|
def fuse(self, x, y): |
|
|
|
"""fuse y to x""" |
|
|
|
for i in self.alive: |
|
|
|
if self.map[y][i] and not self.map[x][i]: |
|
|
|
self.map[x][i] = True |
|
|
|
if self.map[i][y] and not self.map[i][x]: |
|
|
|
self.map[i][x] = True |
|
|
|
self.alive.remove(y) |
|
|
|
|
|
|
|
class Area: |
|
|
|
"""Area""" |
|
|
|
MODE_BASIC = 1 |
|
|
|
@@ -30,7 +58,7 @@ class GraphSplitByPattern: |
|
|
|
self.stitch_ops = set() |
|
|
|
self.stitch_atomic_ops = set() |
|
|
|
|
|
|
|
def __init__(self, init_op, is_output, unique_id): |
|
|
|
def __init__(self, init_op, is_output, unique_id, reach_tab): |
|
|
|
self.pattern = PrimLib.iter_type(init_op) |
|
|
|
self.ops = [init_op] |
|
|
|
self.in_relations = dict() # {area1: relation1, area2: relation2, ...} |
|
|
|
@@ -48,8 +76,8 @@ class GraphSplitByPattern: |
|
|
|
else: |
|
|
|
_gather_reduce_exclude(to) |
|
|
|
_gather_reduce_exclude(init_op) |
|
|
|
self.reach_map = dict() |
|
|
|
self.unique_id = unique_id |
|
|
|
self.reach_tab = reach_tab |
|
|
|
|
|
|
|
def __str__(self): |
|
|
|
return '<' + '-'.join([op.output.name for op in self.ops]) + '>' |
|
|
|
@@ -78,6 +106,8 @@ class GraphSplitByPattern: |
|
|
|
"""Link outputs""" |
|
|
|
for input_area, r in self.in_relations.items(): |
|
|
|
input_area.out_relations[self] = r |
|
|
|
for out, _ in self.out_relations.items(): |
|
|
|
self.reach_tab.sync(self.unique_id, out.unique_id) |
|
|
|
|
|
|
|
def update_stitch_info(self, stitch_info): |
|
|
|
if stitch_info.stitch_ops: |
|
|
|
@@ -124,23 +154,12 @@ class GraphSplitByPattern: |
|
|
|
if area.output_excluded: |
|
|
|
self.output_excluded.update(area.output_excluded) |
|
|
|
self.update_stitch_info(area.stitch_info) |
|
|
|
for to, reach in area.reach_map.items(): |
|
|
|
if reach and not self.reach_map.get(to, False): |
|
|
|
self.reach_map[to] = True |
|
|
|
self.reach_tab.fuse(self.unique_id, area.unique_id) |
|
|
|
|
|
|
|
def check_acyclic(self, to): |
|
|
|
"""Check circle. It returns false if circle exists""" |
|
|
|
def _reached(area, to): |
|
|
|
if to.unique_id in area.reach_map: |
|
|
|
return area.reach_map[to.unique_id] |
|
|
|
for out, _ in area.out_relations.items(): |
|
|
|
if out == to or _reached(out, to): |
|
|
|
area.reach_map[to.unique_id] = True |
|
|
|
return True |
|
|
|
area.reach_map[to.unique_id] = False |
|
|
|
return False |
|
|
|
for out, _ in self.out_relations.items(): |
|
|
|
if out != to and _reached(out, to): |
|
|
|
if out != to and self.reach_tab.reachable(out.unique_id, to.unique_id): |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
@@ -158,20 +177,21 @@ class GraphSplitByPattern: |
|
|
|
self.graph = graph |
|
|
|
self.areas = [] |
|
|
|
self.flags = flags |
|
|
|
self.reach_tab = self.ReachTable(len(graph.ops)) |
|
|
|
area_map = {} |
|
|
|
_, outputs = graph.deduce_parameters() |
|
|
|
idx = 0 |
|
|
|
for op in graph.ops: |
|
|
|
is_output = op.output in outputs |
|
|
|
a = self.Area(op, is_output, idx) |
|
|
|
a = self.Area(op, is_output, idx, self.reach_tab) |
|
|
|
idx += 1 |
|
|
|
self.set_default_mode(a) |
|
|
|
self.areas.append(a) |
|
|
|
area_map[op] = a |
|
|
|
for a in self.areas: |
|
|
|
a.link_input(area_map) |
|
|
|
for a in self.areas: |
|
|
|
a.link_output() |
|
|
|
for i in range(len(self.areas)-1, -1, -1): |
|
|
|
self.areas[i].link_output() |
|
|
|
|
|
|
|
def set_default_mode(self, area): |
|
|
|
area.mode = self.get_default_mode(area.ops[0]) |
|
|
|
@@ -260,7 +280,7 @@ class GraphSplitByPattern: |
|
|
|
break |
|
|
|
if out_reshape_ops: |
|
|
|
for op in out_reshape_ops: |
|
|
|
a = self.Area(op, False, -1) |
|
|
|
a = self.Area(op, False, 0, self.reach_tab) |
|
|
|
self.set_default_mode(a) |
|
|
|
new_areas.append(a) |
|
|
|
area.ops = remain_ops |
|
|
|
|