diff --git a/akg b/akg index aa167aa123..a70345a2ad 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit aa167aa12353813b179171cb59fc69e3ad4a4733 +Subproject commit a70345a2ad67aae8f57d547c44a5c97be04097ab diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 482cf07e8f..34c7d7d857 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -253,6 +253,8 @@ class GraphSplitGpu(GraphSplitByPattern): REDUCE_FUSE_DEPTH = 20 def get_default_mode(self, op): + if op.prim == "BatchMatMul": + return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" else self.Area.MODE_BASIC pattern = PrimLib.iter_type(op) return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index 7853c7787c..5900d3deec 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++ b/mindspore/_extends/graph_kernel/model/model.py @@ -71,6 +71,7 @@ class PrimLib: REDUCE = 4 TRANSFORM = 5 CONTROL = 6 + CONV = 7 class Prim: """Prim""" @@ -128,6 +129,7 @@ class PrimLib: default_reduce_relation, unknown_relation, unknown_relation, + unknown_relation, ] primtives = { @@ -171,6 +173,7 @@ class PrimLib: 'Transpose': Prim(TRANSFORM), 'Tile': Prim(BROADCAST), 'BroadcastTo': Prim(BROADCAST), + 'BatchMatMul': Prim(CONV), } default_primtive = Prim(UNKNOWN) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index ca03d6ed79..7216fe754a 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -749,12 +749,12 @@ std::vector GetFusibleOpList() { prim::kPrimCast, prim::kPrimRealDiv}; #elif ENABLE_GPU std::vector fusible_basic_ops = { - prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, - prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, - prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, - prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, - prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, - prim::kPrimCast, prim::kPrimExpandDims}; + prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, + prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, + prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, + prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater, + prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, + prim::kPrimCast, prim::kPrimExpandDims, prim::kPrimBatchMatMul}; #else std::vector fusible_basic_ops; #endif