Browse Source

!11262 【GraphKernel】Use akg BatchMatMul when dtype is float16.

From: @dayschan
Reviewed-by: @ckey_dou
Signed-off-by:
tags/v1.1.1
mindspore-ci-bot Gitee 5 years ago
parent
commit
072d0c4c8e
4 changed files with 12 additions and 7 deletions
  1. +1
    -1
      akg
  2. +2
    -0
      mindspore/_extends/graph_kernel/model/graph_split.py
  3. +3
    -0
      mindspore/_extends/graph_kernel/model/model.py
  4. +6
    -6
      mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc

+ 1
- 1
akg

@@ -1 +1 @@
Subproject commit aa167aa12353813b179171cb59fc69e3ad4a4733
Subproject commit a70345a2ad67aae8f57d547c44a5c97be04097ab

+ 2
- 0
mindspore/_extends/graph_kernel/model/graph_split.py View File

@@ -253,6 +253,8 @@ class GraphSplitGpu(GraphSplitByPattern):
REDUCE_FUSE_DEPTH = 20

def get_default_mode(self, op):
if op.prim == "BatchMatMul":
return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" else self.Area.MODE_BASIC
pattern = PrimLib.iter_type(op)
return self.Area.MODE_BASIC if pattern == PrimLib.RESHAPE else self.Area.MODE_COMPOSITE



+ 3
- 0
mindspore/_extends/graph_kernel/model/model.py View File

@@ -71,6 +71,7 @@ class PrimLib:
REDUCE = 4
TRANSFORM = 5
CONTROL = 6
CONV = 7

class Prim:
"""Prim"""
@@ -128,6 +129,7 @@ class PrimLib:
default_reduce_relation,
unknown_relation,
unknown_relation,
unknown_relation,
]

primtives = {
@@ -171,6 +173,7 @@ class PrimLib:
'Transpose': Prim(TRANSFORM),
'Tile': Prim(BROADCAST),
'BroadcastTo': Prim(BROADCAST),
'BatchMatMul': Prim(CONV),
}

default_primtive = Prim(UNKNOWN)


+ 6
- 6
mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc View File

@@ -749,12 +749,12 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
prim::kPrimCast, prim::kPrimRealDiv};
#elif ENABLE_GPU
std::vector<PrimitivePtr> fusible_basic_ops = {
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimCast, prim::kPrimExpandDims};
prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd,
prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, prim::kPrimGreater,
prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
prim::kPrimCast, prim::kPrimExpandDims, prim::kPrimBatchMatMul};
#else
std::vector<PrimitivePtr> fusible_basic_ops;
#endif


Loading…
Cancel
Save