| @@ -19,6 +19,34 @@ from .model import PrimLib, Graph, Tensor | |||||
| class GraphSplitByPattern: | class GraphSplitByPattern: | ||||
| """Graph splitter""" | """Graph splitter""" | ||||
| class ReachTable: | |||||
| """Reachable table""" | |||||
| def __init__(self, size): | |||||
| self.map = [] | |||||
| self.alive = set(range(size)) | |||||
| for i in range(0, size): | |||||
| self.map.append([False for j in range(0, size)]) | |||||
| self.map[i][i] = True | |||||
| def reachable(self, x, y): | |||||
| """reachable from x to y""" | |||||
| return self.map[x][y] | |||||
| def sync(self, x, y): | |||||
| """sync from y to x""" | |||||
| for i in self.alive: | |||||
| if self.map[y][i] and not self.map[x][i]: | |||||
| self.map[x][i] = True | |||||
| def fuse(self, x, y): | |||||
| """fuse y to x""" | |||||
| for i in self.alive: | |||||
| if self.map[y][i] and not self.map[x][i]: | |||||
| self.map[x][i] = True | |||||
| if self.map[i][y] and not self.map[i][x]: | |||||
| self.map[i][x] = True | |||||
| self.alive.remove(y) | |||||
| class Area: | class Area: | ||||
| """Area""" | """Area""" | ||||
| MODE_BASIC = 1 | MODE_BASIC = 1 | ||||
| @@ -30,7 +58,7 @@ class GraphSplitByPattern: | |||||
| self.stitch_ops = set() | self.stitch_ops = set() | ||||
| self.stitch_atomic_ops = set() | self.stitch_atomic_ops = set() | ||||
| def __init__(self, init_op, is_output, unique_id): | |||||
| def __init__(self, init_op, is_output, unique_id, reach_tab): | |||||
| self.pattern = PrimLib.iter_type(init_op) | self.pattern = PrimLib.iter_type(init_op) | ||||
| self.ops = [init_op] | self.ops = [init_op] | ||||
| self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | self.in_relations = dict() # {area1: relation1, area2: relation2, ...} | ||||
| @@ -48,8 +76,8 @@ class GraphSplitByPattern: | |||||
| else: | else: | ||||
| _gather_reduce_exclude(to) | _gather_reduce_exclude(to) | ||||
| _gather_reduce_exclude(init_op) | _gather_reduce_exclude(init_op) | ||||
| self.reach_map = dict() | |||||
| self.unique_id = unique_id | self.unique_id = unique_id | ||||
| self.reach_tab = reach_tab | |||||
| def __str__(self): | def __str__(self): | ||||
| return '<' + '-'.join([op.output.name for op in self.ops]) + '>' | return '<' + '-'.join([op.output.name for op in self.ops]) + '>' | ||||
| @@ -78,6 +106,8 @@ class GraphSplitByPattern: | |||||
| """Link outputs""" | """Link outputs""" | ||||
| for input_area, r in self.in_relations.items(): | for input_area, r in self.in_relations.items(): | ||||
| input_area.out_relations[self] = r | input_area.out_relations[self] = r | ||||
| for out, _ in self.out_relations.items(): | |||||
| self.reach_tab.sync(self.unique_id, out.unique_id) | |||||
| def update_stitch_info(self, stitch_info): | def update_stitch_info(self, stitch_info): | ||||
| if stitch_info.stitch_ops: | if stitch_info.stitch_ops: | ||||
| @@ -124,23 +154,12 @@ class GraphSplitByPattern: | |||||
| if area.output_excluded: | if area.output_excluded: | ||||
| self.output_excluded.update(area.output_excluded) | self.output_excluded.update(area.output_excluded) | ||||
| self.update_stitch_info(area.stitch_info) | self.update_stitch_info(area.stitch_info) | ||||
| for to, reach in area.reach_map.items(): | |||||
| if reach and not self.reach_map.get(to, False): | |||||
| self.reach_map[to] = True | |||||
| self.reach_tab.fuse(self.unique_id, area.unique_id) | |||||
| def check_acyclic(self, to): | def check_acyclic(self, to): | ||||
| """Check circle. It returns false if circle exists""" | """Check circle. It returns false if circle exists""" | ||||
| def _reached(area, to): | |||||
| if to.unique_id in area.reach_map: | |||||
| return area.reach_map[to.unique_id] | |||||
| for out, _ in area.out_relations.items(): | |||||
| if out == to or _reached(out, to): | |||||
| area.reach_map[to.unique_id] = True | |||||
| return True | |||||
| area.reach_map[to.unique_id] = False | |||||
| return False | |||||
| for out, _ in self.out_relations.items(): | for out, _ in self.out_relations.items(): | ||||
| if out != to and _reached(out, to): | |||||
| if out != to and self.reach_tab.reachable(out.unique_id, to.unique_id): | |||||
| return False | return False | ||||
| return True | return True | ||||
| @@ -158,20 +177,21 @@ class GraphSplitByPattern: | |||||
| self.graph = graph | self.graph = graph | ||||
| self.areas = [] | self.areas = [] | ||||
| self.flags = flags | self.flags = flags | ||||
| self.reach_tab = self.ReachTable(len(graph.ops)) | |||||
| area_map = {} | area_map = {} | ||||
| _, outputs = graph.deduce_parameters() | _, outputs = graph.deduce_parameters() | ||||
| idx = 0 | idx = 0 | ||||
| for op in graph.ops: | for op in graph.ops: | ||||
| is_output = op.output in outputs | is_output = op.output in outputs | ||||
| a = self.Area(op, is_output, idx) | |||||
| a = self.Area(op, is_output, idx, self.reach_tab) | |||||
| idx += 1 | idx += 1 | ||||
| self.set_default_mode(a) | self.set_default_mode(a) | ||||
| self.areas.append(a) | self.areas.append(a) | ||||
| area_map[op] = a | area_map[op] = a | ||||
| for a in self.areas: | for a in self.areas: | ||||
| a.link_input(area_map) | a.link_input(area_map) | ||||
| for a in self.areas: | |||||
| a.link_output() | |||||
| for i in range(len(self.areas)-1, -1, -1): | |||||
| self.areas[i].link_output() | |||||
| def set_default_mode(self, area): | def set_default_mode(self, area): | ||||
| area.mode = self.get_default_mode(area.ops[0]) | area.mode = self.get_default_mode(area.ops[0]) | ||||
| @@ -260,7 +280,7 @@ class GraphSplitByPattern: | |||||
| break | break | ||||
| if out_reshape_ops: | if out_reshape_ops: | ||||
| for op in out_reshape_ops: | for op in out_reshape_ops: | ||||
| a = self.Area(op, False, -1) | |||||
| a = self.Area(op, False, 0, self.reach_tab) | |||||
| self.set_default_mode(a) | self.set_default_mode(a) | ||||
| new_areas.append(a) | new_areas.append(a) | ||||
| area.ops = remain_ops | area.ops = remain_ops | ||||