| @@ -257,7 +257,7 @@ class GraphSplitByPattern: | |||
| return self.ops[0] | |||
| def reduce_out_exclude(self, area): | |||
| """Check whether op is redcue_out_exclude """ | |||
| """Check whether op is reduce_out_exclude """ | |||
| if self.output_excluded: | |||
| for op in self.output_excluded: | |||
| if op in area.ops: | |||
| @@ -678,7 +678,7 @@ class GraphSplitByPattern: | |||
| class GraphSplitGpu(GraphSplitByPattern): | |||
| """Graph splitter""" | |||
| BORADCAST_FUSE_DEPTH = 20 | |||
| BROADCAST_FUSE_DEPTH = 20 | |||
| REDUCE_FUSE_DEPTH = 20 | |||
| def get_default_mode(self, op): | |||
| @@ -733,7 +733,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| def _broadcast_depth(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH: | |||
| return None | |||
| a, r = list(dom.out_relations.items())[0] | |||
| if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: | |||
| @@ -742,7 +742,7 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| def _broadcast_width(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||
| @@ -956,11 +956,11 @@ class GraphSplitGpu(GraphSplitByPattern): | |||
| class GraphSplitAscend(GraphSplitByPattern): | |||
| """Graph splitter""" | |||
| BORADCAST_FUSE_DEPTH = 6 | |||
| BROADCAST_FUSE_DEPTH = 6 | |||
| REDUCE_FUSE_DEPTH = 10 | |||
| def get_default_mode(self, op): | |||
| """Get efault mode for Ascend""" | |||
| """Get default mode for Ascend""" | |||
| def _dtype_same(tensors): | |||
| dtype = tensors[0].dtype | |||
| @@ -1019,7 +1019,7 @@ class GraphSplitAscend(GraphSplitByPattern): | |||
| return fused, True | |||
| def _broadcast_pat_exclude(dom, a, r): | |||
| if _likely_multicore(a) and (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH): | |||
| if _likely_multicore(a) and (dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH): | |||
| return True | |||
| return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST | |||
| @@ -1046,7 +1046,7 @@ class GraphSplitAscend(GraphSplitByPattern): | |||
| if len(a.ops) > self.REDUCE_FUSE_DEPTH: | |||
| return True | |||
| if r == PrimLib.BROADCAST and _likely_multicore(dom) and \ | |||
| (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH): | |||
| (dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH): | |||
| return True | |||
| return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE | |||
| @@ -1193,7 +1193,7 @@ class GraphSplitAscend(GraphSplitByPattern): | |||
| class GraphSplitCpu(GraphSplitByPattern): | |||
| """Graph splitter""" | |||
| BORADCAST_FUSE_DEPTH = 20 | |||
| BROADCAST_FUSE_DEPTH = 20 | |||
| REDUCE_FUSE_DEPTH = 20 | |||
| def get_default_mode(self, op): | |||
| @@ -1244,7 +1244,7 @@ class GraphSplitCpu(GraphSplitByPattern): | |||
| def _broadcast_depth(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \ | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH: | |||
| return None | |||
| a, r = list(dom.out_relations.items())[0] | |||
| if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1: | |||
| @@ -1253,7 +1253,7 @@ class GraphSplitCpu(GraphSplitByPattern): | |||
| def _broadcast_width(dom): | |||
| if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \ | |||
| dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH: | |||
| dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH: | |||
| return None | |||
| fused = [] | |||
| for a, r in dom.out_relations.items(): | |||