@@ -466,6 +466,8 @@ class GraphSplitAscend(GraphSplitByPattern):
     REDUCE_FUSE_DEPTH = 10

     def get_default_mode(self, op):
+        if op.prim == "MatMul":
+            return self.Area.MODE_COMPOSITE if op.inputs[0].dtype == "float16" else self.Area.MODE_BASIC
         if op.prim in ("Tile", "BroadcastTo"):
             return self.Area.MODE_COMPOSITE
         return self.Area.MODE_BASIC
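This hunk makes MatMul a fusion seed (MODE_COMPOSITE) on Ascend only when its first input is float16, presumably because AKG's Ascend backend only generates Cube GEMM code for fp16; any other dtype falls back to MODE_BASIC. A minimal runnable sketch of the predicate, using hypothetical _Op/_Tensor stubs (not from the source):

class _Tensor:
    def __init__(self, dtype):
        self.dtype = dtype

class _Op:
    def __init__(self, prim, dtype):
        self.prim, self.inputs = prim, [_Tensor(dtype)]

def is_composite(op):
    # Mirrors the new branch in GraphSplitAscend.get_default_mode.
    return op.prim == "MatMul" and op.inputs[0].dtype == "float16"

assert is_composite(_Op("MatMul", "float16"))
assert not is_composite(_Op("MatMul", "float32"))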
@@ -88,8 +88,7 @@ class PrimLib:
     ELEMWISE = 2
     BROADCAST = 3
     REDUCE = 4
-    TRANSFORM = 5
-    CONTROL = 6
+    OPAQUE = 5

     class Prim:
         """Prim"""
@@ -146,7 +145,6 @@ class PrimLib:
         default_elemwise_broadcast_relation,
         default_reduce_relation,
         unknown_relation,
-        unknown_relation,
     ]

     primtives = {
@@ -176,7 +174,6 @@ class PrimLib:
         'ReduceSum': Prim(REDUCE),
         'ReduceMax': Prim(REDUCE),
         'ReduceMin': Prim(REDUCE),
-        'MakeTuple': Prim(CONTROL),
         'Assign': Prim(ELEMWISE),
         'Tanh': Prim(ELEMWISE),
         'ExpandDims': Prim(RESHAPE),
@@ -186,9 +183,10 @@ class PrimLib:
         'Squeeze': Prim(RESHAPE),
         'Flatten': Prim(RESHAPE),
         'FlattenGrad': Prim(RESHAPE),
-        'Transpose': Prim(TRANSFORM),
+        'Transpose': Prim(OPAQUE),
         'Tile': Prim(BROADCAST),
         'BroadcastTo': Prim(BROADCAST),
+        'MatMul': Prim(OPAQUE),
     }

     default_primtive = Prim(UNKNOWN)
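With TRANSFORM and CONTROL collapsed into a single OPAQUE category, the relation-function table above needs one fewer unknown_relation entry, and Transpose plus the newly registered MatMul both resolve to the opaque type. A small self-contained sketch of the lookup, assuming Prim stores its iteration type in an iter_type field as the surrounding code suggests:

UNKNOWN, RESHAPE, ELEMWISE, BROADCAST, REDUCE, OPAQUE = range(6)

class Prim:
    def __init__(self, iter_type):
        self.iter_type = iter_type

primtives = {'Transpose': Prim(OPAQUE), 'MatMul': Prim(OPAQUE), 'Tile': Prim(BROADCAST)}
default_primtive = Prim(UNKNOWN)

def iter_type_of(name):
    # Unregistered primitives fall back to UNKNOWN, as with PrimLib.default_primtive.
    return primtives.get(name, default_primtive).iter_type

assert iter_type_of('MatMul') == OPAQUE
assert iter_type_of('SomeCustomOp') == UNKNOWN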
@@ -509,7 +507,7 @@ class AddControlBuddy(GraphVisitor):
         self.buddies = {}  # {op : [ctrl_op]}

     def visit(self, op):
-        if PrimLib.iter_type(op) == PrimLib.CONTROL:
+        if op.prim == "MakeTuple":
             assert len(op.output.to_ops) == 1
             owner = op.output.to_ops[0]
             if owner in self.buddies:
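Since the CONTROL category no longer exists, AddControlBuddy now identifies MakeTuple by name and records it against its single consumer so the two stay together. A runnable sketch of that bookkeeping with minimal hypothetical stubs (the real classes live in the graph-kernel model):

class _Output:
    def __init__(self):
        self.to_ops = []

class _Op:
    def __init__(self, prim):
        self.prim, self.output = prim, _Output()

make_tuple, owner = _Op("MakeTuple"), _Op("AddN")
make_tuple.output.to_ops.append(owner)

buddies = {}  # {op : [ctrl_op]}, as in AddControlBuddy
if make_tuple.prim == "MakeTuple":
    assert len(make_tuple.output.to_ops) == 1  # MakeTuple must have exactly one consumer
    buddies.setdefault(make_tuple.output.to_ops[0], []).append(make_tuple)
assert buddies[owner] == [make_tuple]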
@@ -177,6 +177,18 @@ void SetAkgAttrsForBN2Relu(const AnfNodePtr &anf_node) {
   AnfAlgo::SetNodeAttr(kAttrOutputNames, MakeValue(bn2_output_names), anf_node);
 }

+void SetAkgAttrsForMatMul(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string dst_type;
+  TypeId output_type = AnfAlgo::GetOutputDeviceDataType(anf_node, 0);
+  dst_type = TypeId2String(output_type);
+  AnfAlgo::SetNodeAttr("dst_type", MakeValue(dst_type), anf_node);
+  auto left_format = AnfAlgo::GetInputFormat(anf_node, 0);
+  auto right_format = AnfAlgo::GetInputFormat(anf_node, 1);
+  AnfAlgo::SetNodeAttr("left_format", MakeValue(left_format), anf_node);
+  AnfAlgo::SetNodeAttr("right_format", MakeValue(right_format), anf_node);
+}
+
 const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_node)>> kAkgKernelAttrsProcessMap = {
   {kFour2FiveOpName, SetAkgAttrsForFour2Five},
   {kFive2FourOpName, SetAkgAttrsForFive2Four},
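For reference, here is (in Python form) roughly what SetAkgAttrsForMatMul records on a float16 MatMul node whose inputs use the default device format. The exact format strings come from the node's kernel build info, so treat these values as an assumption:

matmul_akg_attrs = {
    "dst_type": "float16",            # TypeId2String of the output device dtype
    "left_format": "DefaultFormat",   # device format of input 0 (assumed value)
    "right_format": "DefaultFormat",  # device format of input 1 (assumed value)
}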
@@ -190,6 +202,7 @@ const std::unordered_map<std::string, std::function<void(const AnfNodePtr &anf_n
   {kConvBN1OpName, SetAkgAttrsForConvBN1},
   {kBN2AddReluOpName, SetAkgAttrsForBN2AddRelu},
   {kBN2ReLUOpName, SetAkgAttrsForBN2Relu},
+  {kMatMulOpName, SetAkgAttrsForMatMul},
 };
 }  // namespace
@@ -575,7 +575,7 @@ std::vector<PrimitivePtr> GetFusibleOpList() {
   prim::kPrimExpandDims, prim::kPrimMul, prim::kPrimMinimum, prim::kPrimMaximum, prim::kPrimLog,
   prim::kPrimPow, prim::kPrimSub, prim::kPrimRsqrt, prim::kPrimSqrt, prim::kPrimAddN,
   prim::kPrimEqual, prim::kPrimReciprocal, prim::kPrimTanh, prim::kPrimReshape, prim::kPrimTranspose,
-  prim::kPrimCast, prim::kPrimRealDiv};
+  prim::kPrimCast, prim::kPrimRealDiv, prim::kPrimMatMul};
 #elif ENABLE_GPU
   std::vector<PrimitivePtr> fusible_basic_ops = {
     prim::kPrimAbs, prim::kPrimRound, prim::kPrimNeg, prim::kPrimExp, prim::kPrimAdd,
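With kPrimMatMul in the Ascend fusible list, MatMul nodes become candidates for graph-kernel clustering once the feature is switched on. The tests below enable it the same way:

import mindspore.context as context

# GRAPH_MODE plus enable_graph_kernel=True activates the fusion pass on Ascend.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    enable_graph_kernel=True)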
@@ -265,6 +265,7 @@ constexpr auto kSGDName = "SGD";
 constexpr auto kLARSUpdateName = "LARSUpdate";
 constexpr auto kBasicLSTMCellCStateGradOpName = "BasicLSTMCellCStateGrad";
 constexpr auto kBasicLSTMCellCStateGradV2OpName = "BasicLSTMCellCStateGradV2";
+constexpr auto kMatMulOpName = "MatMul";
 constexpr auto kMatMulV2OpName = "MatMulV2";
 constexpr auto kBroadcastToOpName = "BroadcastTo";
 constexpr auto kFusedAddReluV2Name = "FusedAddReluV2";
@@ -0,0 +1,88 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import numpy as np
+import pytest
+
+import mindspore.context as context
+from mindspore import Tensor
+from mindspore.nn import Cell
+import mindspore.ops.operations as P
+
+
+class Net(Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)
+
+    def construct(self, x, y):
+        return self.matmul(x, y)
+
+
+class Net1(Cell):
+    def __init__(self):
+        super(Net1, self).__init__()
+        self.matmul = P.MatMul(transpose_a=True, transpose_b=True)
+        self.add = P.BiasAdd()
+
+    def construct(self, x, y, bias):
+        res = self.matmul(x, y)
+        return self.add(res, bias)
+
+
+def get_output(i0, i1, enable_graph_kernel=False):
+    # Set the flag explicitly in both cases so one run cannot leak into the next.
+    context.set_context(enable_graph_kernel=enable_graph_kernel, save_graphs=False)
+    net = Net()
+    output = net(i0, i1)
+    return output
+
+
+def get_output1(i0, i1, i2, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel, save_graphs=False)
+    net = Net1()
+    output = net(i0, i1, i2)
+    return output
+
+
+def test_basic():
+    i0 = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
+    i1 = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
+    expect = get_output(i0, i1, False)
+    output = get_output(i0, i1, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1.e-4, 1.e-7)
+
+
+def test_basic1():
+    i0 = Tensor(np.random.normal(1, 0.01, [800, 96]).astype(np.float16))
+    i1 = Tensor(np.random.normal(1, 0.01, [128, 800]).astype(np.float16))
+    i2 = Tensor(np.random.normal(100, 0.01, [128]).astype(np.float16))
+    expect = get_output1(i0, i1, i2, False)
+    output = get_output1(i0, i1, i2, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 6.e-4, 6.e-4)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_basic_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_basic()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_basic_ascend1():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_basic1()
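As a quick sanity check on the shapes used above (pure NumPy, no MindSpore required): with transpose_a=True and transpose_b=True, the [800, 96] and [128, 800] inputs produce a [96, 128] product, which is what lets the [128] bias broadcast in Net1:

import numpy as np

x = np.ones((800, 96), np.float16)   # i0
y = np.ones((128, 800), np.float16)  # i1
out = np.matmul(x.T, y.T)            # MatMul(transpose_a=True, transpose_b=True)
assert out.shape == (96, 128)        # BiasAdd's [128] bias applies along the last axis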