| @@ -584,6 +584,39 @@ class GraphSplitAscend(GraphSplitByPattern): | |||||
| fused.append(a) | fused.append(a) | ||||
| return fused, False | return fused, False | ||||
| def _transdata_pattern_support(dom, a): | |||||
| transdata_op = dom.dom_op() | |||||
| # Currently, if transdata has the pad, it is not used to fuse | |||||
| def _has_pad(): | |||||
| res = False | |||||
| input_shape = transdata_op.inputs[0].shape | |||||
| output_shape = transdata_op.output.shape | |||||
| cube_size = 16 | |||||
| for dim in input_shape[-2:]: | |||||
| if dim % cube_size != 0: | |||||
| res = True | |||||
| for dim in output_shape[-2:]: | |||||
| if dim % cube_size != 0: | |||||
| res = True | |||||
| return res | |||||
| has_pad = _has_pad() | |||||
| if has_pad: | |||||
| return False | |||||
| if a.dom_op().prim == "MatMul" and len(dom.ops) == 1: | |||||
| return True | |||||
| return False | |||||
| def _transdata(dom): | |||||
| if dom.dom_op().prim != "TransData": | |||||
| return None | |||||
| fused = [] | |||||
| for a, _ in dom.in_relations.items(): | |||||
| if _transdata_pattern_support(dom, a) and a.check_acyclic(dom): | |||||
| fused.append(a) | |||||
| return fused, True | |||||
| changed = True | changed = True | ||||
| while changed: | while changed: | ||||
| changed = self.fuse(_reshape) | changed = self.fuse(_reshape) | ||||
| @@ -594,6 +627,8 @@ class GraphSplitAscend(GraphSplitByPattern): | |||||
| changed = self.fuse(_broadcast_depth) or changed | changed = self.fuse(_broadcast_depth) or changed | ||||
| changed = self.fuse(_broadcast_width) or changed | changed = self.fuse(_broadcast_width) or changed | ||||
| changed = self.fuse(_matmul_depth) or changed | changed = self.fuse(_matmul_depth) or changed | ||||
| self.fuse(_transdata) | |||||
| def split(graph, target, flags): | def split(graph, target, flags): | ||||
| """Split graph""" | """Split graph""" | ||||
| @@ -186,6 +186,7 @@ class PrimLib: | |||||
| 'Tile': Prim(BROADCAST), | 'Tile': Prim(BROADCAST), | ||||
| 'BroadcastTo': Prim(BROADCAST), | 'BroadcastTo': Prim(BROADCAST), | ||||
| 'MatMul': Prim(OPAQUE), | 'MatMul': Prim(OPAQUE), | ||||
| 'TransData': Prim(OPAQUE), | |||||
| } | } | ||||
| default_primtive = Prim(UNKNOWN) | default_primtive = Prim(UNKNOWN) | ||||
| @@ -596,11 +596,11 @@ std::string ExtractGraphKernelName(const AnfNodePtrList &cnodes, const string &p | |||||
| std::vector<PrimitivePtr> GetFusibleOpList() { | std::vector<PrimitivePtr> GetFusibleOpList() { | ||||
| #if ENABLE_D | #if ENABLE_D | ||||
| std::vector<PrimitivePtr> fusible_basic_ops = { | std::vector<PrimitivePtr> fusible_basic_ops = { | ||||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | |||||
| prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | |||||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | |||||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | |||||
| prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum}; | |||||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | |||||
| prim::kPrimCast, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | |||||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | |||||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | |||||
| prim::kPrimRealDiv, prim::kPrimMatMul, prim::kPrimAssign, prim::kPrimReduceSum, prim::KPrimTransData}; | |||||
| #elif ENABLE_GPU | #elif ENABLE_GPU | ||||
| std::vector<PrimitivePtr> fusible_basic_ops = { | std::vector<PrimitivePtr> fusible_basic_ops = { | ||||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | ||||