
test_cus_matmul_dds.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import MatmulDDS

from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)
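
# GradOperation(get_all=True) builds a gradient function that returns the
# gradients with respect to all network inputs; GradWrap below relies on this.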
grad_all = C.GradOperation(get_all=True)


# Expected MatmulDDS tensor layouts (bs = batch_size):
# q: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# k: (num_heads * size_per_head // 16, bs * seq_len // 16, 16, 16)
# local_mask: (block_num * block_size // 16, bs * block_size // 16, 16, 16)
# global_mask: (bs * global_size // 16, seq_len // 16, 16, 16)
# local_prob: (bs, num_heads, block_num, block_size // 16, block_size // 16, 16, 16)
# global_prob: (bs, num_heads, block_num, global_size // 16, block_size // 16, 16, 16)
# x: (bs * seq_len, num_heads * size_per_head)
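#
# For concreteness (illustrative arithmetic only, using the configuration the
# tests below construct: batch_size=128, num_heads=32, size_per_head=128,
# seq_len=1024, block_size=64, block_num=16, global_size=256):
#   q / k:       (32 * 128 // 16, 128 * 1024 // 16, 16, 16) = (256, 8192, 16, 16)
#   local_mask:  (16 * 64 // 16, 128 * 64 // 16, 16, 16)    = (64, 512, 16, 16)
#   global_mask: (128 * 256 // 16, 1024 // 16, 16, 16)      = (2048, 64, 16, 16)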
class Net(nn.Cell):
    """Computes local/global probabilities with MatmulDDS; dp and mp define the
    sharding strategies applied when `shard` is True."""

    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.size_per_head = 128
        self.seq_len = 1024
        self.block_size = 64
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.embedding_size = num_heads * self.size_per_head
        self.cus_matmul = MatmulDDS(batch_size, num_heads)
        self.reduce_sum = P.ReduceSum()
        self.global_mask = Tensor(np.ones((batch_size * self.global_size // 16, self.seq_len // 16, 16, 16)))
        self.local_mask = Tensor(np.ones((self.block_num * self.block_size // 16,
                                          batch_size * self.block_size // 16, 16, 16)))
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.add = P.Add()
        if shard:
            self.cus_matmul.shard(((mp, dp, 1, 1), (mp, dp, 1, 1), (1, dp, 1, 1), (dp, 1, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))

    def construct(self, x):
        q = self.dense1(x)
        k = self.dense2(x)
        # Reshape/transpose into the 4-D (..., 16, 16) layout described above.
        q = self.transpose(self.reshape(q, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
        k = self.transpose(self.reshape(k, (-1, 16, self.embedding_size // 16, 16)), (2, 0, 1, 3))
        local_prob, global_prob = self.cus_matmul(q, k, self.local_mask, self.global_mask)
        local_prob = self.reshape(local_prob, (self.batch_size, self.num_heads, -1))
        global_prob = self.reshape(global_prob, (self.batch_size, self.num_heads, -1))
        local_prob_reduce = self.reduce_sum(local_prob, 2)
        global_prob_reduce = self.reduce_sum(global_prob, 2)
        result = self.add(local_prob_reduce, global_prob_reduce)
        return result


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
    """Build the wrapped network and compile it under the requested parallel mode."""
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
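

# Note (illustrative): every test below compiles against device_num=16. The
# sharded cases choose dp * mp == 16 so the strategies cover all devices; the
# "repeat_cal" cases pass dp * mp < 16 with shard=False and rely on
# set_algo_parameters(fully_use_devices=False), so the searched strategies may
# leave several devices computing the same slice (repeated calculation).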


def test_cus_matmul_dds_model_parallel_mix():
    """Semi-auto parallel with mixed data/model parallel sharding (dp=2, mp=8)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_dp():
    """Semi-auto parallel with pure data parallel sharding (dp=16, mp=1)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_mp():
    """Semi-auto parallel with pure model parallel sharding (dp=1, mp=16)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp)


def test_cus_matmul_dds_model_parallel_mix_auto():
    """Auto parallel with user-specified mixed sharding (dp=2, mp=8)."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_dp_auto():
    """Auto parallel with user-specified data parallel sharding (dp=16, mp=1)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_mp_auto():
    """Auto parallel with user-specified model parallel sharding (dp=1, mp=16)."""
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_cus_matmul_dds_model_parallel_auto():
    """Auto parallel with full strategy search (no user-specified sharding)."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)


def test_cus_matmul_dds_repeat_cal_auto():
    """Auto parallel strategy search permitting repeated calculation (dp=1, mp=2)."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 2
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)


def test_cus_matmul_dds_repeat1_cal_auto():
    """Same as above with the dp/mp arguments swapped (dp=2, mp=1)."""
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)
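

# Minimal sketch for invoking one case directly (an assumption: these tests are
# normally collected by pytest inside the MindSpore UT tree, so the `tests`
# package import above must resolve on PYTHONPATH for a direct run to work):
if __name__ == "__main__":
    test_cus_matmul_dds_model_parallel_mix()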