|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- import numpy as np
-
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import Tensor
- from mindspore import context
- from mindspore.common.api import _cell_graph_executor
- from mindspore.ops import composite as C
- from mindspore.ops import operations as P
- from mindspore.parallel import set_algo_parameters
- from mindspore.ops.operations._inner_ops import MatmulDDS
- from tests.ut.python.ops.test_math_ops import VirtualLoss
-
# Every test below compiles its network through compile_graph /
# _cell_graph_executor.compile, with the global context set to graph mode.
context.set_context(mode=context.GRAPH_MODE)

# Module-wide gradient operator used by GradWrap.construct; get_all=True
# requests gradients with respect to every input of the wrapped network.
grad_all = C.GradOperation(get_all=True)
-
- # q: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
- # k: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
- # local_mask: (block_num * block_size // 16, bs * block_size // 16, 16, 16)
- # global_mask: (bs * global_size // 16, seq_len // 16, 16, 16)
- # local_prob: (bs, num_heads, block_num, block_size // 16, block_size // 16, 16, 16)
- # global_prob: (bs, num_heads, block_num, global_size // 16, block_size // 16, 16, 16)
- # x: (bs*seq_len, num_heads*size_per_head)
class Net(nn.Cell):
    """Two dense projections feeding MatmulDDS, summed down to (batch_size, num_heads).

    The geometry is fixed: size_per_head=128, seq_len=1024, block_size=64,
    global_size=256. When ``shard`` is True, explicit strategies derived from
    ``dp`` (data-parallel factor) and ``mp`` (model-parallel factor) are put on
    MatmulDDS, both dense matmuls and the transpose. When ``shard`` is False,
    ``dp`` and ``mp`` are not used at all.
    """

    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        # Fixed problem geometry.
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.size_per_head = 128
        self.seq_len = 1024
        self.block_size = 64
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.embedding_size = num_heads * self.size_per_head

        self.cus_matmul = MatmulDDS(batch_size, num_heads)
        self.reduce_sum = P.ReduceSum()

        # Constant all-ones masks; both are 4-D with trailing 16 x 16 dims.
        global_mask_shape = (batch_size * self.global_size // 16, self.seq_len // 16, 16, 16)
        local_mask_shape = (self.block_num * self.block_size // 16,
                            batch_size * self.block_size // 16, 16, 16)
        self.global_mask = Tensor(np.ones(global_mask_shape))
        self.local_mask = Tensor(np.ones(local_mask_shape))

        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.add = P.Add()

        if shard:
            # Both dense projections share the same matmul strategy.
            dense_strategy = ((dp, 1), (mp, 1))
            self.cus_matmul.shard(((mp, dp, 1, 1), (mp, dp, 1, 1), (1, dp, 1, 1), (dp, 1, 1, 1)))
            self.dense1.matmul.shard(dense_strategy)
            self.dense2.matmul.shard(dense_strategy)
            self.transpose.shard(((dp, 1, mp, 1),))

    def construct(self, x):
        # Fold each projection to (-1, 16, E // 16, 16), then move axis 2 to the
        # front, giving (E // 16, rows // 16, 16, 16) where E = embedding_size.
        tile_shape = (-1, 16, self.embedding_size // 16, 16)
        perm = (2, 0, 1, 3)
        q_proj = self.dense1(x)
        k_proj = self.dense2(x)
        q_tiles = self.transpose(self.reshape(q_proj, tile_shape), perm)
        k_tiles = self.transpose(self.reshape(k_proj, tile_shape), perm)

        local_prob, global_prob = self.cus_matmul(q_tiles, k_tiles, self.local_mask, self.global_mask)

        # Flatten everything past (batch, heads) and sum it away.
        flat_shape = (self.batch_size, self.num_heads, -1)
        local_flat = self.reshape(local_prob, flat_shape)
        global_flat = self.reshape(global_prob, flat_shape)
        local_total = self.reduce_sum(local_flat, 2)
        global_total = self.reduce_sum(global_flat, 2)
        return self.add(local_total, global_total)
-
-
class GradWrap(nn.Cell):
    """Return the gradients of ``network`` with respect to all of its inputs."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x):
        # grad_all is the module-level GradOperation(get_all=True).
        grad_fn = grad_all(self.network)
        return grad_fn(x)
-
-
class NetWithLoss(nn.Cell):
    """Feed the output of ``network`` straight into a VirtualLoss."""

    def __init__(self, network):
        super().__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        return self.loss(self.network(x))
-
-
def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
    """Build the gradient graph of ``Net`` and compile it under a parallel mode.

    Args:
        batch_size: batch size passed to ``Net`` (and on to MatmulDDS).
        num_heads: number of attention heads passed to ``Net``.
        dp: data-parallel factor used by ``Net``'s shard strategies.
        mp: model-parallel factor used by ``Net``'s shard strategies.
        auto: compile with "auto_parallel" when True, otherwise
            "semi_auto_parallel".
        shard: forwarded to ``Net``. When False, no explicit strategies are
            set and ``dp``/``mp`` are unused.
    """
    parallel_mode = "auto_parallel" if auto else "semi_auto_parallel"
    context.set_auto_parallel_context(parallel_mode=parallel_mode)

    model = Net(batch_size, num_heads, dp, mp, shard=shard)
    # Take the input geometry from the network itself rather than repeating
    # its seq_len (1024) and size_per_head (128) literals here, so the two
    # cannot drift apart.
    x = Tensor(np.ones((batch_size * model.seq_len, model.embedding_size)), ms.float32)

    net = GradWrap(NetWithLoss(model))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
-
def test_cus_matmul_dds_model_parallel_mix():
    """Semi-auto parallel, 16 devices, explicit dp=2 x mp=8 strategies."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    compile_graph(batch_size=128, num_heads=32, dp=2, mp=8)
-
def test_cus_matmul_dds_model_parallel_dp():
    """Semi-auto parallel, 16 devices, pure data parallel (dp=16, mp=1)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    compile_graph(batch_size=128, num_heads=32, dp=16, mp=1)
-
def test_cus_matmul_dds_model_parallel_mp():
    """Semi-auto parallel, 16 devices, pure model parallel (dp=1, mp=16)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    compile_graph(batch_size=128, num_heads=32, dp=1, mp=16)
-
def test_cus_matmul_dds_model_parallel_mix_auto():
    """Auto parallel with explicit dp=2 x mp=8 strategies; fully_use_devices off."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    compile_graph(batch_size=128, num_heads=32, dp=2, mp=8, auto=True)
-
def test_cus_matmul_dds_model_parallel_dp_auto():
    """Auto parallel with explicit pure data-parallel strategies (dp=16, mp=1)."""
    # NOTE(review): this test does not call set_algo_parameters itself, so it
    # presumably runs with whatever value an earlier test in the same process
    # left behind; confirm that the order dependence is intended.
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    compile_graph(batch_size=128, num_heads=32, dp=16, mp=1, auto=True)
-
def test_cus_matmul_dds_model_parallel_mp_auto():
    """Auto parallel with explicit pure model-parallel strategies (dp=1, mp=16)."""
    # NOTE(review): like the dp_auto case, this does not set the algo
    # parameters itself; presumably it inherits earlier state. Confirm.
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    compile_graph(batch_size=128, num_heads=32, dp=1, mp=16, auto=True)
-
def test_cus_matmul_dds_model_parallel_auto():
    """Auto parallel with no explicit strategies; fully_use_devices off."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    # shard=False: Net never reads dp/mp, so the values below are inert.
    compile_graph(batch_size=128, num_heads=32, dp=1, mp=16, auto=True, shard=False)
-
def test_cus_matmul_dds_repeat_cal_auto():
    """Auto parallel, no explicit strategies, fully_use_devices off (dp=1, mp=2)."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    # NOTE(review): with shard=False, Net ignores dp/mp, so this compiles the
    # same graph as test_cus_matmul_dds_model_parallel_auto. If the goal was
    # to exercise a strategy using fewer than 16 devices, shard=True may be
    # needed; confirm the intent.
    compile_graph(batch_size=128, num_heads=32, dp=1, mp=2, auto=True, shard=False)
-
def test_cus_matmul_dds_repeat1_cal_auto():
    """Auto parallel, no explicit strategies, fully_use_devices off (dp=2, mp=1)."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    # NOTE(review): with shard=False, Net ignores dp/mp, so this compiles the
    # same graph as test_cus_matmul_dds_model_parallel_auto; confirm intent.
    compile_graph(batch_size=128, num_heads=32, dp=2, mp=1, auto=True, shard=False)
|