diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 665c0b98ff..82c8d8e322 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -488,7 +488,7 @@ class GraphSplitGpu(GraphSplitByPattern): stitch_tensors = [tensor for tensor in dom_outs if tensor in a_ins] if _same_stitch_axis(stitch_tensors, a_final_outs): for tensor in stitch_tensors: - if _tensor_size(tensor) >= 1024 * 1024 * 12: + if _tensor_size(tensor) >= 1024 * 1024: return True return False diff --git a/tests/st/ops/graph_kernel/test_dropoutgrad_reducesum_stitch.py b/tests/st/ops/graph_kernel/test_dropoutgrad_reducesum_stitch.py new file mode 100644 index 0000000000..3dbc83bc5a --- /dev/null +++ b/tests/st/ops/graph_kernel/test_dropoutgrad_reducesum_stitch.py @@ -0,0 +1,89 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +from mindspore import Tensor +from mindspore.nn import Cell +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as GP +from mindspore.common import dtype as mstype +import pytest + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") +# enable graph kernel optimization. +context.set_context(enable_graph_kernel=True) + + +class BertAttentionGradPiece(Cell): + def __init__(self): + super(BertAttentionGradPiece, self).__init__() + self.add = P.Add() + self.reducesum = P.ReduceSum(keep_dims=True) + self.dropout_grad = GP.DropoutGrad(1 - 0.1) + self.sub = P.Sub() + self.multiply = P.Mul() + self.cast = P.Cast() + + def construct(self, x, y, z): + out1 = self.dropout_grad(x, y) + out2 = self.multiply(out1, z) + out3 = self.reducesum(self.cast(out2, mstype.float32), (-1,)) + out4 = self.sub(out1, self.cast(out3, mstype.float16)) + return out4 + + +def get_rtol_atol(dtype): + if dtype == np.float16: + return 1.e-3, 1.e-3 + return 1.e-4, 1.e-4 + + +def compare_result(expect, output, dtype): + rtol, atol = get_rtol_atol(dtype) + if isinstance(expect, (list, tuple)): + assert isinstance(output, (list, tuple)) and len(expect) == len(output) + expect_list = list(expect) + output_list = list(output) + for e, o in zip(expect_list, output_list): + assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True) + else: + assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True) + + +def get_dropoutgrad_reducesum_output(x, y, z, enable_stitch_fusion): + # enable graph kernel stitch fusion. + if enable_stitch_fusion: + context.set_context(graph_kernel_flags="--enable_stitch_fusion=true") + net = BertAttentionGradPiece() + result = net(x, y, z) + return result + + +def test_dropoutgrad_reducesum(shape, dtype): + x = Tensor(np.random.normal(0, 1, shape).astype(dtype)) + y = Tensor(np.random.normal(0, 1, shape).astype(dtype)) + z = Tensor(np.random.normal(0, 1, shape).astype(dtype)) + expect = get_dropoutgrad_reducesum_output(x, y, z, False) + output = get_dropoutgrad_reducesum_output(x, y, z, True) + compare_result(expect, output, dtype) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_dropoutgrad_reducesum_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_dropoutgrad_reducesum([64, 12, 128, 128], np.float16) diff --git a/tests/st/ops/graph_kernel/test_layernorm_stitch.py b/tests/st/ops/graph_kernel/test_layernorm_stitch.py new file mode 100644 index 0000000000..542b307a0f --- /dev/null +++ b/tests/st/ops/graph_kernel/test_layernorm_stitch.py @@ -0,0 +1,86 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +from mindspore import Tensor +import mindspore.nn as nn +from mindspore.nn import Cell +from mindspore.ops import operations as P +import pytest + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") +# enable graph kernel optimization. +context.set_context(enable_graph_kernel=True) + + +class EmbeddingPostprocessor(Cell): + def __init__(self): + super(EmbeddingPostprocessor, self).__init__() + self.layernorm = nn.LayerNorm((768,)) + self.add = P.Add() + self.dropout = nn.Dropout(1 - 0.1) + + def construct(self, word_embeddings, token_type_embeddings, position_embeddings): + output = word_embeddings + output = self.add(output, token_type_embeddings) + output = self.add(output, position_embeddings) + output = self.layernorm(output) + output = self.dropout(output) + return output + + +def get_rtol_atol(dtype): + if dtype == np.float16: + return 1.e-3, 1.e-3 + return 1.e-4, 1.e-4 + + +def compare_result(expect, output, dtype): + rtol, atol = get_rtol_atol(dtype) + if isinstance(expect, (list, tuple)): + assert isinstance(output, (list, tuple)) and len(expect) == len(output) + expect_list = list(expect) + output_list = list(output) + for e, o in zip(expect_list, output_list): + assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True) + else: + assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True) + + +def get_layernorm_output(x, y, z, enable_stitch_fusion): + # enable graph kernel stitch fusion. + if enable_stitch_fusion: + context.set_context(graph_kernel_flags="--enable_stitch_fusion=true") + net = EmbeddingPostprocessor() + result = net(x, y, z) + return result + + +def test_layernorm(shape1, shape2, dtype): + x = Tensor(np.random.normal(0, 1, shape1).astype(dtype)) + y = Tensor(np.random.normal(0, 1, shape1).astype(dtype)) + z = Tensor(np.random.normal(0, 1, shape2).astype(dtype)) + expect = get_layernorm_output(x, y, z, False) + output = get_layernorm_output(x, y, z, True) + compare_result(expect, output, dtype) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_layernorm_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_layernorm([8192, 768], [1, 768], np.float32) diff --git a/tests/st/ops/graph_kernel/test_softmax_stitch.py b/tests/st/ops/graph_kernel/test_softmax_stitch.py new file mode 100644 index 0000000000..49994ca6e4 --- /dev/null +++ b/tests/st/ops/graph_kernel/test_softmax_stitch.py @@ -0,0 +1,92 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import mindspore.context as context +from mindspore import Tensor +import mindspore.nn as nn +from mindspore.nn import Cell +from mindspore.ops import operations as P +import mindspore.ops.functional as F +import pytest + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") +# enable graph kernel optimization. +context.set_context(enable_graph_kernel=True) + + +class BertAttentionPiece(Cell): + def __init__(self): + super(BertAttentionPiece, self).__init__() + self.add = P.Add() + self.dropout = nn.Dropout(1 - 0.1) + self.softmax = nn.Softmax() + self.multiply_data = -10000.0 + self.sub = P.Sub() + self.multiply = P.Mul() + self.get_dtype = P.DType() + self.cast = P.Cast() + + def construct(self, attention_mask, attention_scores): + multiply_out = self.sub(self.cast(F.tuple_to_array((1.0,)), self.get_dtype(attention_scores)), + self.cast(attention_mask, self.get_dtype(attention_scores))) + adder = self.multiply(multiply_out, self.multiply_data) + attention_scores = self.add(adder, attention_scores) + attention_probs = self.softmax(attention_scores) + attention_probs = self.dropout(attention_probs) + return attention_probs + + +def get_rtol_atol(dtype): + if dtype == np.float16: + return 1.e-3, 1.e-3 + return 1.e-4, 1.e-4 + + +def compare_result(expect, output, dtype): + rtol, atol = get_rtol_atol(dtype) + if isinstance(expect, (list, tuple)): + assert isinstance(output, (list, tuple)) and len(expect) == len(output) + expect_list = list(expect) + output_list = list(output) + for e, o in zip(expect_list, output_list): + assert np.allclose(e.asnumpy(), o.asnumpy(), rtol, atol, equal_nan=True) + else: + assert np.allclose(expect.asnumpy(), output.asnumpy(), rtol, atol, equal_nan=True) + + +def get_softmax_output(x, y, enable_stitch_fusion): + # enable graph kernel stitch fusion. + if enable_stitch_fusion: + context.set_context(graph_kernel_flags="--enable_stitch_fusion=true") + net = BertAttentionPiece() + result = net(x, y) + return result + + +def test_softmax(shape, dtype): + x = Tensor(np.random.normal(0, 1, shape).astype(dtype)) + y = Tensor(np.random.normal(0, 1, shape).astype(dtype)) + expect = get_softmax_output(x, y, False) + output = get_softmax_output(x, y, True) + compare_result(expect, output, dtype) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_softmax_gpu(): + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + test_softmax([64, 12, 128, 128], np.float16)