Browse Source

Corrected typos

tags/v1.6.0
Samuel Batissou 4 years ago
parent
commit
b5757a42d2
1 changed files with 11 additions and 11 deletions
  1. +11
    -11
      mindspore/_extends/graph_kernel/model/graph_split.py

+ 11
- 11
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -257,7 +257,7 @@ class GraphSplitByPattern:
return self.ops[0]

def reduce_out_exclude(self, area):
"""Check whether op is redcue_out_exclude """
"""Check whether op is reduce_out_exclude """
if self.output_excluded:
for op in self.output_excluded:
if op in area.ops:
@@ -678,7 +678,7 @@ class GraphSplitByPattern:

class GraphSplitGpu(GraphSplitByPattern):
"""Graph splitter"""
BORADCAST_FUSE_DEPTH = 20
BROADCAST_FUSE_DEPTH = 20
REDUCE_FUSE_DEPTH = 20

def get_default_mode(self, op):
@@ -733,7 +733,7 @@ class GraphSplitGpu(GraphSplitByPattern):

def _broadcast_depth(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
return None
a, r = list(dom.out_relations.items())[0]
if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
@@ -742,7 +742,7 @@ class GraphSplitGpu(GraphSplitByPattern):

def _broadcast_width(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
return None
fused = []
for a, r in dom.out_relations.items():
@@ -956,11 +956,11 @@ class GraphSplitGpu(GraphSplitByPattern):

class GraphSplitAscend(GraphSplitByPattern):
"""Graph splitter"""
BORADCAST_FUSE_DEPTH = 6
BROADCAST_FUSE_DEPTH = 6
REDUCE_FUSE_DEPTH = 10

def get_default_mode(self, op):
"""Get efault mode for Ascend"""
"""Get default mode for Ascend"""

def _dtype_same(tensors):
dtype = tensors[0].dtype
@@ -1019,7 +1019,7 @@ class GraphSplitAscend(GraphSplitByPattern):
return fused, True

def _broadcast_pat_exclude(dom, a, r):
if _likely_multicore(a) and (dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH):
if _likely_multicore(a) and (dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH):
return True
return a.pattern > PrimLib.REDUCE or r > PrimLib.BROADCAST

@@ -1046,7 +1046,7 @@ class GraphSplitAscend(GraphSplitByPattern):
if len(a.ops) > self.REDUCE_FUSE_DEPTH:
return True
if r == PrimLib.BROADCAST and _likely_multicore(dom) and \
(dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH):
(dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH):
return True
return a.pattern > PrimLib.BROADCAST or r > PrimLib.REDUCE

@@ -1193,7 +1193,7 @@ class GraphSplitAscend(GraphSplitByPattern):

class GraphSplitCpu(GraphSplitByPattern):
"""Graph splitter"""
BORADCAST_FUSE_DEPTH = 20
BROADCAST_FUSE_DEPTH = 20
REDUCE_FUSE_DEPTH = 20

def get_default_mode(self, op):
@@ -1244,7 +1244,7 @@ class GraphSplitCpu(GraphSplitByPattern):

def _broadcast_depth(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
return None
a, r = list(dom.out_relations.items())[0]
if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
@@ -1253,7 +1253,7 @@ class GraphSplitCpu(GraphSplitByPattern):

def _broadcast_width(dom):
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
dom.is_output or len(dom.ops) > self.BROADCAST_FUSE_DEPTH:
return None
fused = []
for a, r in dom.out_relations.items():


Loading…
Cancel
Save