| @@ -19,6 +19,34 @@ from .model import PrimLib, Graph, Tensor | |||
| class GraphSplitByPattern: | |||
| """Graph splitter""" | |||
| class ReachTable: | |||
| """Reachable table""" | |||
| def __init__(self, size): | |||
| self.map = [] | |||
| self.alive = set(range(size)) | |||
| for i in range(0, size): | |||
| self.map.append([False for j in range(0, size)]) | |||
| self.map[i][i] = True | |||
| def reachable(self, x, y): | |||
| """reachable from x to y""" | |||
| return self.map[x][y] | |||
| def sync(self, x, y): | |||
| """sync from y to x""" | |||
| for i in self.alive: | |||
| if self.map[y][i] and not self.map[x][i]: | |||
| self.map[x][i] = True | |||
| def fuse(self, x, y): | |||
| """fuse y to x""" | |||
| for i in self.alive: | |||
| if self.map[y][i] and not self.map[x][i]: | |||
| self.map[x][i] = True | |||
| if self.map[i][y] and not self.map[i][x]: | |||
| self.map[i][x] = True | |||
| self.alive.remove(y) | |||
| class Area: | |||
| """Area""" | |||
| MODE_BASIC = 1 | |||
| @@ -30,7 +58,7 @@ class GraphSplitByPattern: | |||
| self.stitch_ops = set() | |||
| self.stitch_atomic_ops = set() | |||
| def __init__(self, init_op, is_output, unique_id): | |||
| def __init__(self, init_op, is_output, unique_id, reach_tab): | |||
| self.pattern = PrimLib.iter_type(init_op) | |||
| self.ops = [init_op] | |||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | |||
| @@ -48,8 +76,8 @@ class GraphSplitByPattern: | |||
| else: | |||
| _gather_reduce_exclude(to) | |||
| _gather_reduce_exclude(init_op) | |||
| self.reach_map = dict() | |||
| self.unique_id = unique_id | |||
| self.reach_tab = reach_tab | |||
| def __str__(self): | |||
| return '<' + '-'.join([op.output.name for op in self.ops]) + '>' | |||
| @@ -78,6 +106,8 @@ class GraphSplitByPattern: | |||
| """Link outputs""" | |||
| for input_area, r in self.in_relations.items(): | |||
| input_area.out_relations[self] = r | |||
| for out, _ in self.out_relations.items(): | |||
| self.reach_tab.sync(self.unique_id, out.unique_id) | |||
| def update_stitch_info(self, stitch_info): | |||
| if stitch_info.stitch_ops: | |||
| @@ -124,23 +154,12 @@ class GraphSplitByPattern: | |||
| if area.output_excluded: | |||
| self.output_excluded.update(area.output_excluded) | |||
| self.update_stitch_info(area.stitch_info) | |||
| for to, reach in area.reach_map.items(): | |||
| if reach and not self.reach_map.get(to, False): | |||
| self.reach_map[to] = True | |||
| self.reach_tab.fuse(self.unique_id, area.unique_id) | |||
| def check_acyclic(self, to): | |||
| """Check circle. It returns false if circle exists""" | |||
| def _reached(area, to): | |||
| if to.unique_id in area.reach_map: | |||
| return area.reach_map[to.unique_id] | |||
| for out, _ in area.out_relations.items(): | |||
| if out == to or _reached(out, to): | |||
| area.reach_map[to.unique_id] = True | |||
| return True | |||
| area.reach_map[to.unique_id] = False | |||
| return False | |||
| for out, _ in self.out_relations.items(): | |||
| if out != to and _reached(out, to): | |||
| if out != to and self.reach_tab.reachable(out.unique_id, to.unique_id): | |||
| return False | |||
| return True | |||
| @@ -158,20 +177,21 @@ class GraphSplitByPattern: | |||
| self.graph = graph | |||
| self.areas = [] | |||
| self.flags = flags | |||
| self.reach_tab = self.ReachTable(len(graph.ops)) | |||
| area_map = {} | |||
| _, outputs = graph.deduce_parameters() | |||
| idx = 0 | |||
| for op in graph.ops: | |||
| is_output = op.output in outputs | |||
| a = self.Area(op, is_output, idx) | |||
| a = self.Area(op, is_output, idx, self.reach_tab) | |||
| idx += 1 | |||
| self.set_default_mode(a) | |||
| self.areas.append(a) | |||
| area_map[op] = a | |||
| for a in self.areas: | |||
| a.link_input(area_map) | |||
| for a in self.areas: | |||
| a.link_output() | |||
| for i in range(len(self.areas)-1, -1, -1): | |||
| self.areas[i].link_output() | |||
| def set_default_mode(self, area): | |||
| area.mode = self.get_default_mode(area.ops[0]) | |||
| @@ -260,7 +280,7 @@ class GraphSplitByPattern: | |||
| break | |||
| if out_reshape_ops: | |||
| for op in out_reshape_ops: | |||
| a = self.Area(op, False, -1) | |||
| a = self.Area(op, False, 0, self.reach_tab) | |||
| self.set_default_mode(a) | |||
| new_areas.append(a) | |||
| area.ops = remain_ops | |||