|
|
@@ -15,7 +15,9 @@ |
|
|
"""Cost model splitter""" |
|
|
"""Cost model splitter""" |
|
|
import os |
|
|
import os |
|
|
from functools import reduce |
|
|
from functools import reduce |
|
|
|
|
|
from mindspore import log as logger |
|
|
from .model import PrimLib, Graph, Tensor |
|
|
from .model import PrimLib, Graph, Tensor |
|
|
|
|
|
from .model import DataFormat as DF |
|
|
|
|
|
|
|
|
class GraphSplitByPattern: |
|
|
class GraphSplitByPattern: |
|
|
"""Graph splitter""" |
|
|
"""Graph splitter""" |
|
|
@@ -674,6 +676,21 @@ class GraphSplitAscend(GraphSplitByPattern): |
|
|
|
|
|
|
|
|
if a.dom_op().prim == "MatMul" and len(dom.ops) == 1: |
|
|
if a.dom_op().prim == "MatMul" and len(dom.ops) == 1: |
|
|
return True |
|
|
return True |
|
|
|
|
|
|
|
|
|
|
|
# reshape/elewise/broadcast + transdata |
|
|
|
|
|
if a.pattern <= PrimLib.BROADCAST and len(dom.ops) == 1: |
|
|
|
|
|
op_attrs = dom.dom_op().attrs |
|
|
|
|
|
if 'src_format' not in op_attrs.keys() \ |
|
|
|
|
|
or 'dst_format' not in op_attrs.keys(): |
|
|
|
|
|
logger.error("src_format or dst_format not be found in the attrs of Transdata op") |
|
|
|
|
|
return False |
|
|
|
|
|
src_format, dst_format = op_attrs['src_format'], op_attrs['dst_format'] |
|
|
|
|
|
if src_format == DF.FRAC_NZ and dst_format in (DF.DEFAULT, DF.NCHW): |
|
|
|
|
|
return True |
|
|
|
|
|
# For the Default/NCHW to FRAC_NZ, currently only the Cast+Transdata is supported |
|
|
|
|
|
if src_format in (DF.DEFAULT, DF.NCHW) and dst_format == DF.FRAC_NZ\ |
|
|
|
|
|
and len(a.ops) == 1 and a.dom_op().prim == "Cast" and not a.is_output: |
|
|
|
|
|
return True |
|
|
return False |
|
|
return False |
|
|
|
|
|
|
|
|
def _transdata(dom): |
|
|
def _transdata(dom): |
|
|
|