| @@ -466,6 +466,8 @@ class GraphSplitAscend(GraphSplitByPattern): | |||||
| REDUCE_FUSE_DEPTH = 10 | REDUCE_FUSE_DEPTH = 10 | ||||
def get_default_mode(self, op):
    """Return the area mode an operator starts with.

    "Tile" and "BroadcastTo" always start as MODE_COMPOSITE. "MatMul" starts
    as MODE_COMPOSITE only when the dtype of its first input is "float16".
    Everything else (including non-float16 "MatMul") starts as MODE_BASIC.
    """
    area = self.Area
    prim = op.prim
    if prim in ("Tile", "BroadcastTo"):
        return area.MODE_COMPOSITE
    # Only the float16 MatMul is promoted; other dtypes fall through to basic.
    if prim == "MatMul" and op.inputs[0].dtype == "float16":
        return area.MODE_COMPOSITE
    return area.MODE_BASIC
| @@ -88,8 +88,7 @@ class PrimLib: | |||||
| ELEMWISE = 2 | ELEMWISE = 2 | ||||
| BROADCAST = 3 | BROADCAST = 3 | ||||
| REDUCE = 4 | REDUCE = 4 | ||||
| TRANSFORM = 5 | |||||
| CONTROL = 6 | |||||
| OPAQUE = 5 | |||||
| class Prim: | class Prim: | ||||
| """Prim""" | """Prim""" | ||||
| @@ -146,7 +145,6 @@ class PrimLib: | |||||
| default_elemwise_broadcast_relation, | default_elemwise_broadcast_relation, | ||||
| default_reduce_relation, | default_reduce_relation, | ||||
| unknown_relation, | unknown_relation, | ||||
| unknown_relation, | |||||
| ] | ] | ||||
| primtives = { | primtives = { | ||||
| @@ -176,7 +174,6 @@ class PrimLib: | |||||
| 'ReduceSum': Prim(REDUCE), | 'ReduceSum': Prim(REDUCE), | ||||
| 'ReduceMax': Prim(REDUCE), | 'ReduceMax': Prim(REDUCE), | ||||
| 'ReduceMin': Prim(REDUCE), | 'ReduceMin': Prim(REDUCE), | ||||
| 'MakeTuple': Prim(CONTROL), | |||||
| 'Assign': Prim(ELEMWISE), | 'Assign': Prim(ELEMWISE), | ||||
| 'Tanh': Prim(ELEMWISE), | 'Tanh': Prim(ELEMWISE), | ||||
| 'ExpandDims': Prim(RESHAPE), | 'ExpandDims': Prim(RESHAPE), | ||||
| @@ -186,9 +183,10 @@ class PrimLib: | |||||
| 'Squeeze': Prim(RESHAPE), | 'Squeeze': Prim(RESHAPE), | ||||
| 'Flatten': Prim(RESHAPE), | 'Flatten': Prim(RESHAPE), | ||||
| 'FlattenGrad': Prim(RESHAPE), | 'FlattenGrad': Prim(RESHAPE), | ||||
| 'Transpose': Prim(TRANSFORM), | |||||
| 'Transpose': Prim(OPAQUE), | |||||
| 'Tile': Prim(BROADCAST), | 'Tile': Prim(BROADCAST), | ||||
| 'BroadcastTo': Prim(BROADCAST), | 'BroadcastTo': Prim(BROADCAST), | ||||
| 'MatMul': Prim(OPAQUE), | |||||
| } | } | ||||
| default_primtive = Prim(UNKNOWN) | default_primtive = Prim(UNKNOWN) | ||||
| @@ -509,7 +507,7 @@ class AddControlBuddy(GraphVisitor): | |||||
| self.buddies = {} # {op : [ctrl_op]} | self.buddies = {} # {op : [ctrl_op]} | ||||
| def visit(self, op): | def visit(self, op): | ||||
| if PrimLib.iter_type(op) == PrimLib.CONTROL: | |||||
| if op.prim == "MakeTuple": | |||||
| assert len(op.output.to_ops) == 1 | assert len(op.output.to_ops) == 1 | ||||
| owner = op.output.to_ops[0] | owner = op.output.to_ops[0] | ||||
| if owner in self.buddies: | if owner in self.buddies: | ||||
| @@ -177,6 +177,18 @@ void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) { | |||||
| AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node); | AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node); | ||||
| } | } | ||||
| void SetAkgAttrsForMatMul(const AnfNodePtr &anf_node) { | |||||
| MS_EXCEPTION_IF_NULL(anf_node); | |||||
| std::string dst_type; | |||||
| TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0); | |||||
| dst_type = TypeId2String(output_type); | |||||
| AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node); | |||||
| auto left_format = AnfAlgo::GetInputFormat(anf_node, 0); | |||||
| auto right_format = AnfAlgo::GetInputFormat(anf_node, 1); | |||||
| AnfAlgo::SetNodeAttr("left_format", MakeValue(left_format), anf_node); | |||||
| AnfAlgo::SetNodeAttr("right_format", MakeValue(right_format), anf_node); | |||||
| } | |||||
| const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = { | const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = { | ||||
| {kFour2FiveOpName, SetAkgAttrsForFour2Five}, | {kFour2FiveOpName, SetAkgAttrsForFour2Five}, | ||||
| {kFive2FourOpName, SetAkgAttrsForFive2Four}, | {kFive2FourOpName, SetAkgAttrsForFive2Four}, | ||||
| @@ -190,6 +202,7 @@ const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_n | |||||
| {kConvBN1OpName, SetAkgAttrsForConvBN1}, | {kConvBN1OpName, SetAkgAttrsForConvBN1}, | ||||
| {kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu}, | {kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu}, | ||||
| {kBN2ReLUOpName, SetAkgAttrsForBN2Relu}, | {kBN2ReLUOpName, SetAkgAttrsForBN2Relu}, | ||||
| {kMatMulOpName, SetAkgAttrsForMatMul}, | |||||
| }; | }; | ||||
| } // namespace | } // namespace | ||||
| @@ -575,7 +575,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() { | |||||
| prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog, | ||||
| prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN, | ||||
| prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose, | ||||
| prim::kPrimCast, prim::kPrimRealDiv}; | |||||
| prim::kPrimCast, prim::kPrimRealDiv, prim::kPrimMatMul}; | |||||
| #elif ENABLE_GPU | #elif ENABLE_GPU | ||||
| std::vector<PrimitivePtr> fusible_basic_ops = { | std::vector<PrimitivePtr> fusible_basic_ops = { | ||||
| prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd, | ||||
| @@ -265,6 +265,7 @@ constexpr auto kSGDName = "SGD"; | |||||
| constexpr auto kLARSUpdateName = "LARSUpdate"; | constexpr auto kLARSUpdateName = "LARSUpdate"; | ||||
| constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad"; | constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad"; | ||||
| constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2"; | constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2"; | ||||
| constexpr auto kMatMulOpName = "MatMul"; | |||||
| constexpr auto kMatMulV2OpName = "MatMulV2"; | constexpr auto kMatMulV2OpName = "MatMulV2"; | ||||
| constexpr auto kBroadcastToOpName = "BroadcastTo"; | constexpr auto kBroadcastToOpName = "BroadcastTo"; | ||||
| constexpr auto kFusedAddReluV2Name = "FusedAddReluV2"; | constexpr auto kFusedAddReluV2Name = "FusedAddReluV2"; | ||||
| @@ -0,0 +1,88 @@ | |||||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================ | |||||
| import numpy as np | |||||
| import pytest | |||||
| import mindspore.context as context | |||||
| from mindspore import Tensor | |||||
| from mindspore.nn import Cell | |||||
| import mindspore.ops.operations as P | |||||
class Net(Cell):
    """Single-operator network: MatMul with both operands transposed."""

    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)

    def construct(self, x, y):
        """Return matmul(x, y) using the transposing MatMul primitive."""
        product = self.matmul(x, y)
        return product
class Net1(Cell):
    """MatMul (both operands transposed) followed by BiasAdd."""

    def __init__(self):
        super().__init__()
        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)
        self.add = P.BiasAdd()

    def construct(self, x, y, bias):
        """Return BiasAdd(matmul(x, y), bias)."""
        return self.add(self.matmul(x, y), bias)
def get_output(i0, i1, enable_graph_kernel=False):
    """Run Net on (i0, i1) and return its output.

    Args:
        i0: first MatMul operand.
        i1: second MatMul operand.
        enable_graph_kernel: whether graph kernel fusion is on for this run.

    Returns:
        The output produced by Net.
    """
    if enable_graph_kernel:
        context.set_context(enable_graph_kernel=True, save_graphs=False)
    else:
        # The context is process-global: switch the flag off explicitly so a
        # baseline run is not silently fused after an earlier call enabled it.
        context.set_context(enable_graph_kernel=False)
    net = Net()
    return net(i0, i1)
def get_output1(i0, i1, i2, enable_graph_kernel=False):
    """Run Net1 on (i0, i1, i2) and return its output.

    Args:
        i0: first MatMul operand.
        i1: second MatMul operand.
        i2: bias passed to BiasAdd.
        enable_graph_kernel: whether graph kernel fusion is on for this run.

    Returns:
        The output produced by Net1.
    """
    if enable_graph_kernel:
        context.set_context(enable_graph_kernel=True, save_graphs=False)
    else:
        # The context is process-global: switch the flag off explicitly so a
        # baseline run is not silently fused after an earlier call enabled it.
        context.set_context(enable_graph_kernel=False)
    net = Net1()
    return net(i0, i1, i2)
def test_basic():
    """Check that MatMul gives matching results with and without graph kernel.

    Inputs are random float16 tensors of shape [800, 96] and [128, 800].
    """
    lhs = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
    rhs = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
    # Baseline first, then the graph-kernel run, on the same inputs.
    baseline = get_output(lhs, rhs, False)
    fused = get_output(lhs, rhs, True)
    baseline_np = baseline.asnumpy().copy()
    fused_np = fused.asnumpy().copy()
    assert np.allclose(baseline_np, fused_np, rtol=1.e-4, atol=1.e-7)
def test_basic1():
    """Check that MatMul + BiasAdd matches with and without graph kernel.

    Inputs are random float16 tensors of shape [800, 96], [128, 800] and a
    bias of shape [128].
    """
    lhs = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
    rhs = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
    bias = Tensor(np.random.normal(100, 0.01, [128]).astype(np.float16))
    # Baseline first, then the graph-kernel run, on the same inputs.
    baseline = get_output1(lhs, rhs, bias, False)
    fused = get_output1(lhs, rhs, bias, True)
    baseline_np = baseline.asnumpy().copy()
    fused_np = fused.asnumpy().copy()
    assert np.allclose(baseline_np, fused_np, rtol=6.e-4, atol=6.e-4)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend():
    """Run test_basic in graph mode with the Ascend device target."""
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
    )
    test_basic()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_basic_ascend1():
    """Run test_basic1 in graph mode with the Ascend device target."""
    context.set_context(
        mode=context.GRAPH_MODE,
        device_target="Ascend",
    )
    test_basic1()