Browse Source

enable GraphKernel for TransData

pull/14514/head
hanhuifeng2020 4 years ago
parent
commit
25505642ce
3 changed files with 41 additions and 5 deletions
  1. +35
    -0
      mindspore/_extends/graph_kernel/model/graph_split.py
  2. +1
    -0
      mindspore/_extends/graph_kernel/model/model.py
  3. +5
    -5
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc

+ 35
- 0
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -584,6 +584,39 @@ class GraphSplitAscend(GraphSplitByPattern):
fused.append(a)
return fused, False

def _transdata_pattern_support(dom, a):
transdata_op = dom.dom_op()

# Currently, if transdata has the pad, it is not used to fuse
def _has_pad():
res = False
input_shape = transdata_op.inputs[0].shape
output_shape = transdata_op.output.shape
cube_size = 16
for dim in input_shape[-2:]:
if dim % cube_size != 0:
res = True
for dim in output_shape[-2:]:
if dim % cube_size != 0:
res = True
return res
has_pad = _has_pad()
if has_pad:
return False

if a.dom_op().prim == "MatMul" and len(dom.ops) == 1:
return True
return False

def _transdata(dom):
    """Fusion pattern rooted at a TransData area.

    Args:
        dom: area under consideration; its `dom_op().prim` and `in_relations`
            are read.

    Returns:
        None when `dom.dom_op().prim` is not "TransData". Otherwise a tuple
        `(areas, True)`, where `areas` lists every key of `dom.in_relations`
        that is accepted by `_transdata_pattern_support` and for which
        `check_acyclic(dom)` holds. The meaning of the constant second element
        is defined by the caller and is not visible in this block.
    """
    if dom.dom_op().prim != "TransData":
        return None
    accepted = [
        area
        for area, _ in dom.in_relations.items()
        # The pattern check runs first; the acyclic check runs only if it passes.
        if _transdata_pattern_support(dom, area) and area.check_acyclic(dom)
    ]
    return accepted, True

changed = True
while changed:
changed = self.fuse(_reshape)
@@ -594,6 +627,8 @@ class GraphSplitAscend(GraphSplitByPattern):
changed = self.fuse(_broadcast_depth) or changed
changed = self.fuse(_broadcast_width) or changed
changed = self.fuse(_matmul_depth) or changed
self.fuse(_transdata)


def split(graph, target, flags):
"""Split graph"""


+ 1
- 0
mindspore/_extends/graph_kernel/model/model.py View File

@@ -186,6 +186,7 @@ class PrimLib:
'Tile': Prim(BROADCAST),
'BroadcastTo': Prim(BROADCAST),
'MatMul': Prim(OPAQUE),
'TransData': Prim(OPAQUE),
}

default_primtive = Prim(UNKNOWN)


+ 5
- 5
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc View File

@@ -596,11 +596,11 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
std::vector<PrimitivePtr> GetFusibleOpList() {
#if ENABLE_D
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum};
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::KPrimTransData};
#elif ENABLE_GPU
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,


Loading…
Cancel
Save