
test_dsd_matmul.py

# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
from mindspore.ops.operations._inner_ops import DSDMatmul
from tests.ut.python.ops.test_math_ops import VirtualLoss

context.set_context(mode=context.GRAPH_MODE)

grad_all = C.GradOperation(get_all=True)
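# grad_all returns the gradients with respect to all positional inputs of the
# wrapped network (here, just x), since get_all=True.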


# input_w1, the shape is (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
# input_w1 cum_shape = batch_size * seq_len * embedding_size * (block_size // size_per_head)
#                    = batch_size * seq_len * (embedding_size // 2)
# input_w2, the shape is (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
# input_w2 cum_shape = batch_size * seq_len * embedding_size * (global_size // size_per_head)
#                    = batch_size * seq_len * embedding_size * 2
# input_v, the shape is (batch_size * seq_len // 16, head * v_embedding // 16, 16, 16)
# block_num = seq_len // block_size; block_size = 64; and head * v_embedding = embedding_size, always.
# output shape is (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
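# As a concrete check, with the values fixed below and in the tests
# (batch_size = 128, num_heads = 32, seq_len = 1024, block_size = head_size = 64,
# global_size = 256, v_embedding = 128), we get block_num = 1024 // 64 = 16 and
# embedding_size = 32 * 128 = 4096, so:
#   input_w1: (128, 32, 16, 4, 4, 16, 16)
#   input_w2: (128, 32, 16, 16, 4, 16, 16)
#   input_v:  (8192, 256, 16, 16)
#   output:   (128, 32, 8, 64, 16, 16)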
class Net(nn.Cell):
    def __init__(self, batch_size, num_heads, dp, mp, shard=True):
        super(Net, self).__init__()
        self.batch_size = batch_size
        self.num_heads = num_heads
        self.seq_len = 1024
        self.block_size = 64
        self.head_size = self.block_size
        self.block_num = self.seq_len // self.block_size
        self.global_size = 256
        self.v_embedding = 128
        self.embedding_size = num_heads * self.v_embedding
        self.dsd_matmul = DSDMatmul()
        self.reduce_sum = P.ReduceSum()
        self.dense1 = nn.Dense(self.embedding_size, self.embedding_size // 2, has_bias=False)
        self.dense2 = nn.Dense(self.embedding_size, self.embedding_size * 2, has_bias=False)
        self.dense3 = nn.Dense(self.embedding_size, self.embedding_size, has_bias=False)
        self.reshape = P.Reshape()
        self.transpose = P.Transpose()
        self.transpose1 = P.Transpose()
        self.add = P.Add()
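        # The shard strategies below follow MindSpore's convention: each entry of
        # a strategy tuple is the number of slices for the corresponding input
        # dimension, so the batch dimension is split dp ways and the head
        # dimension mp ways.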
        if shard:
            self.dsd_matmul.shard(((dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1, 1, 1, 1), (dp, mp, 1, 1)))
            self.dense1.matmul.shard(((dp, 1), (mp, 1)))
            self.dense2.matmul.shard(((dp, 1), (mp, 1)))
            self.dense3.matmul.shard(((dp, 1), (mp, 1)))
            self.transpose.shard(((dp, 1, mp, 1),))
            self.transpose1.shard(((dp, mp, 1, 1, 1, 1),))

    def construct(self, x):
        # x (batch_size * seq_len, embedding_size)
        q = self.dense1(x)
        # q (batch_size * seq_len, embedding_size // 2)
        # (batch_size, head, block_num, head_size // 16, block_size // 16, 16, 16)
        k = self.dense2(x)
        # k (batch_size * seq_len, embedding_size * 2)
        # (batch_size, head, block_num, global_size // 16, head_size // 16, 16, 16)
        v = self.dense3(x)
        # v (batch_size * seq_len, embedding_size)
        q = self.reshape(q, (self.batch_size, self.num_heads, self.block_num, self.head_size // 16,
                             self.block_size // 16, 16, 16))
        k = self.reshape(k, (self.batch_size, self.num_heads, self.block_num, self.global_size // 16,
                             self.head_size // 16, 16, 16))
        v = self.transpose(self.reshape(v, (-1, 16, self.embedding_size // 16, 16)), (0, 2, 3, 1))
        dsd = self.dsd_matmul(q, k, v)
        # dsd (batch_size, head, v_embedding // 16, seq_len // 16, 16, 16)
        dsd = self.transpose1(dsd, (0, 1, 3, 4, 2, 5))
        # dsd (batch_size, head, seq_len // 16, 16, v_embedding // 16, 16)
        dsd = self.reshape(dsd, (-1, self.seq_len, self.v_embedding * self.num_heads))
        result = self.reduce_sum(dsd, 2)
        return result


class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        return grad_all(self.network)(x)


class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        predict = self.network(x)
        return self.loss(predict)


def compile_graph(batch_size, num_heads, dp, mp, auto=False, shard=True):
    if auto:
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
    else:
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    x = Tensor(np.ones((batch_size * 1024, num_heads * 128)), ms.float32)
    net = GradWrap(NetWithLoss(Net(batch_size, num_heads, dp, mp, shard=shard)))
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
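

# Each test below configures 16 devices and chooses dp * mp == 16, so the shard
# strategies above exactly cover the device mesh (dp: data parallel, mp: model parallel).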

def test_dsd_matmul_model_parallel_mix():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_dp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_mp():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp)


def test_dsd_matmul_model_parallel_mix_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 2
    mp = 8
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_dp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 16
    mp = 1
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_mp_auto():
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True)


def test_dsd_matmul_model_parallel_auto():
    set_algo_parameters(fully_use_devices=False)
    context.set_auto_parallel_context(device_num=16, global_rank=0)
    batch_size = 128
    num_heads = 32
    dp = 1
    mp = 16
    compile_graph(batch_size, num_heads, dp, mp, auto=True, shard=False)