diff --git a/akg b/akg index 2956e64803..1866f35fe0 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit 2956e64803cad9b84316cdf2b25d034c5f944ccc +Subproject commit 1866f35fe0d1f10acfc1da0a69e9cb44cf37bb4c diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 1b7d7fddf0..6c278037b0 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -30,6 +30,8 @@ class GraphSplitByPattern: self.in_relations = dict() # {area1: relation1, area2: relation2, ...} self.out_relations = dict() # {area1: relation1, area2: relation2, ...} self.mode = self.MODE_BASIC + if self.pattern == PrimLib.TRANSFORM: + self.mode = self.MODE_COMPOSITE def __str__(self): return '<' + '-'.join([op.output.name for op in self.ops]) + '>' diff --git a/mindspore/_extends/graph_kernel/model/model.py b/mindspore/_extends/graph_kernel/model/model.py index 5561ef213f..0f1a32d377 100644 --- a/mindspore/_extends/graph_kernel/model/model.py +++ b/mindspore/_extends/graph_kernel/model/model.py @@ -157,6 +157,8 @@ class PrimLib: 'ExpandDims': Prim(ELEMWISE), 'InplaceAssign': Prim(ELEMWISE), '@ReduceInit': Prim(ELEMWISE), + 'Reshape': Prim(ELEMWISE), + 'Transpose': Prim(TRANSFORM), } default_primtive = Prim(UNKNOWN) diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index 9b0a5a1fa7..766c1a0547 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -726,11 +726,12 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p std::vector GetFusibleOpList() { std::vector fusible_basic_ops = { - prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, - prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, - prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, - prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, - prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape}; + prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimTensorAdd, + prim::kPrimRealDiv, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, + prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimCast, + prim::kPrimAddN, prim::kPrimEqual, prim::kPrimReciprocal, prim::KPrimTransData, prim::kPrimSelect, + prim::kPrimGreater, prim::kPrimAssign, prim::kPrimReduceSum, prim::kPrimTanh, prim::kPrimReshape, + prim::kPrimTranspose}; return fusible_basic_ops; }