From 195b1fe8d54be238af6615a4cec65a99e9ca8348 Mon Sep 17 00:00:00 2001 From: dayschan Date: Tue, 10 Nov 2020 15:38:04 +0800 Subject: [PATCH] Add Transpose into fusible list. --- akg | 2 +- mindspore/_extends/graph_kernel/model/graph_split.py | 2 ++ mindspore/_extends/graph_kernel/model/model.py | 2 ++ .../optimizer/graph_kernel/graph_kernel_helper.cc | 11 ++++++----- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/akg b/akg index 2956e64803..1866f35fe0 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit 2956e64803cad9b84316cdf2b25d034c5f944ccc +Subproject commit 1866f35fe0d1f10acfc1da0a69e9cb44cf37bb4c diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 1b7d7fddf0..6c278037b0 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -30,6 +30,8 @@ class GraphSplitByPattern: self.in_relations = dict() # {area1: relation1, area2: relation2, ...} self.out_relations = dict() # {area1: relation1, area2: relation2, ...} self.mode = self.MODE_BASIC + if self.pattern == PrimLib.TRANSFORM: + self.mode = self.MODE_COMPOSITE def __str__(self): return '<' + '-'.join([op.output.name for op in self.ops]) + '>' diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index 5561ef213f..0f1a32d377 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++ b/mindspore/_extends/graph_kernel/model/model.py @@ -157,6 +157,8 @@ class PrimLib: 'ExpandDims': Prim(ELEMWISE), 'InplaceAssign': Prim(ELEMWISE), '@ReduceInit': Prim(ELEMWISE), + 'Reshape': Prim(ELEMWISE), + 'Transpose': Prim(TRANSFORM), } default_primtive = Prim(UNKNOWN) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 9b0a5a1fa7..766c1a0547 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -726,11 +726,12 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p std::vector GetFusibleOpList() { std::vector fusible_basic_ops = { - prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, - prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, - prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, - prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, - prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape}; + prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, + prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, + prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, + prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, + prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, + prim::kPrimTranspose}; return fusible_basic_ops; }