|
|
|
@@ -145,9 +145,10 @@ class GraphSplitByPattern: |
|
|
|
return True |
|
|
|
return False |
|
|
|
|
|
|
|
def __init__(self, graph): |
|
|
|
def __init__(self, graph, flags): |
|
|
|
self.graph = graph |
|
|
|
self.areas = [] |
|
|
|
self.flags = flags |
|
|
|
area_map = {} |
|
|
|
_, outputs = graph.deduce_parameters() |
|
|
|
for op in graph.ops: |
|
|
|
@@ -450,6 +451,7 @@ class GraphSplitGpu(GraphSplitByPattern): |
|
|
|
fused.append(a) |
|
|
|
return fused, True |
|
|
|
|
|
|
|
enable_stitch_fusion = self.flags.get("enable_stitch_fusion", False) |
|
|
|
changed = True |
|
|
|
while changed: |
|
|
|
changed = self.fuse(_reshape) |
|
|
|
@@ -461,7 +463,8 @@ class GraphSplitGpu(GraphSplitByPattern): |
|
|
|
changed = self.fuse(_broadcast_width) or changed |
|
|
|
if use_poly_reduce: |
|
|
|
changed = self.fuse(_reduce_output) or changed |
|
|
|
changed = self.fuse(_reduce_stitch) or changed |
|
|
|
if enable_stitch_fusion: |
|
|
|
changed = self.fuse(_reduce_stitch) or changed |
|
|
|
self.fuse(_transpose) |
|
|
|
|
|
|
|
class GraphSplitAscend(GraphSplitByPattern): |
|
|
|
@@ -582,11 +585,11 @@ class GraphSplitAscend(GraphSplitByPattern): |
|
|
|
changed = self.fuse(_broadcast_depth) or changed |
|
|
|
changed = self.fuse(_broadcast_width) or changed |
|
|
|
|
|
|
|
def split(graph, target): |
|
|
|
def split(graph, target, flags): |
|
|
|
"""Split graph""" |
|
|
|
result = None |
|
|
|
if target == "cuda": |
|
|
|
result = GraphSplitGpu(graph).split() |
|
|
|
result = GraphSplitGpu(graph, flags).split() |
|
|
|
else: |
|
|
|
result = GraphSplitAscend(graph).split() |
|
|
|
result = GraphSplitAscend(graph, flags).split() |
|
|
|
return result |