diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py
index 3f1b5bc18b..c0286944b6 100644
--- a/mindspore/_extends/graph_kernel/model/graph_split.py
+++ b/mindspore/_extends/graph_kernel/model/graph_split.py
@@ -584,6 +584,39 @@ class GraphSplitAscend(GraphSplitByPattern):
                     fused.append(a)
             return fused, False
 
+        def _transdata_pattern_support(dom, a):
+            transdata_op = dom.dom_op()
+
+            # Currently a TransData that pads (a trailing dim not a multiple of cube_size) is not fused.
+            def _has_pad():
+                res = False
+                input_shape = transdata_op.inputs[0].shape
+                output_shape = transdata_op.output.shape
+                cube_size = 16
+                for dim in input_shape[-2:]:
+                    if dim % cube_size != 0:
+                        res = True
+                for dim in output_shape[-2:]:
+                    if dim % cube_size != 0:
+                        res = True
+                return res
+            has_pad = _has_pad()
+            if has_pad:
+                return False
+
+            if a.dom_op().prim == "MatMul" and len(dom.ops) == 1:
+                return True
+            return False
+
+        def _transdata(dom):
+            if dom.dom_op().prim != "TransData":
+                return None
+            fused = []
+            for a, _ in dom.in_relations.items():
+                if _transdata_pattern_support(dom, a) and a.check_acyclic(dom):
+                    fused.append(a)
+            return fused, True
+
         changed = True
         while changed:
             changed = self.fuse(_reshape)
@@ -594,6 +627,8 @@ class GraphSplitAscend(GraphSplitByPattern):
             changed = self.fuse(_broadcast_depth) or changed
             changed = self.fuse(_broadcast_width) or changed
             changed = self.fuse(_matmul_depth) or changed
+        self.fuse(_transdata)
+
 
 def split(graph, target, flags):
     """Split graph"""
diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py
index 214c6e7984..359001b08a 100644
--- a/mindspore/_extends/graph_kernel/model/model.py
+++ b/mindspore/_extends/graph_kernel/model/model.py
@@ -186,6 +186,7 @@ class PrimLib:
         'Tile': Prim(BROADCAST),
         'BroadcastTo': Prim(BROADCAST),
         'MatMul': Prim(OPAQUE),
+        'TransData': Prim(OPAQUE),
     }
 
     default_primtive = Prim(UNKNOWN)
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
index 6738b031fa..1b904dc977 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc
@@ -596,11 +596,11 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p
 std::vector<PrimitivePtr> GetFusibleOpList() {
 #if ENABLE_D
   std::vector<PrimitivePtr> fusible_basic_ops = {
-    prim::kPrimAbs,     prim::kPrimRound,      prim::kPrimNeg,     prim::kPrimExp,     prim::kPrimAdd,
-    prim::kPrimCast,    prim::kPrimMul,        prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
-    prim::kPrimPow,     prim::kPrimSub,        prim::kPrimRsqrt,   prim::kPrimSqrt,    prim::kPrimAddN,
-    prim::kPrimEqual,   prim::kPrimReciprocal, prim::kPrimTanh,    prim::kPrimReshape, prim::kPrimTranspose,
-    prim::kPrimRealDiv, prim::kPrimMatMul,     prim::kPrimAssign,  prim::kPrimReduceSum};
+    prim::kPrimAbs,     prim::kPrimRound,      prim::kPrimNeg,     prim::kPrimExp,      prim::kPrimAdd,
+    prim::kPrimCast,    prim::kPrimMul,        prim::kPrimMinimum, prim::kPrimMaximum,  prim::kPrimLog,
+    prim::kPrimPow,     prim::kPrimSub,        prim::kPrimRsqrt,   prim::kPrimSqrt,     prim::kPrimAddN,
+    prim::kPrimEqual,   prim::kPrimReciprocal, prim::kPrimTanh,    prim::kPrimReshape,  prim::kPrimTranspose,
+    prim::kPrimRealDiv, prim::kPrimMatMul,     prim::kPrimAssign,  prim::kPrimReduceSum,
+    prim::KPrimTransData};  // NOTE(review): capital 'K' breaks the kPrim naming convention — confirm it matches the primitive's actual declaration
 #elif ENABLE_GPU
   std::vector<PrimitivePtr> fusible_basic_ops = {
     prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,