You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_virtual_output.py 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
import re
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _executor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
# All tests below compile under graph mode; each test sets its own
# auto-parallel context (device_num / parallel_mode / full_batch).
context.set_context(mode=context.GRAPH_MODE)
  25. class DenseMutMulNet(nn.Cell):
  26. def __init__(self):
  27. super(DenseMutMulNet, self).__init__()
  28. self.fc1 = nn.Dense(128, 768)
  29. self.fc2 = nn.Dense(128, 768)
  30. self.fc3 = nn.Dense(128, 768)
  31. self.fc4 = nn.Dense(768, 768, has_bias=False)
  32. self.relu4 = nn.ReLU()
  33. self.relu5 = nn.ReLU()
  34. self.transpose = P.Transpose()
  35. self.matmul1 = P.MatMul()
  36. self.matmul2 = P.MatMul()
  37. self.fc4.matmul.shard(((1, 1), (8, 1)))
  38. def construct(self, x):
  39. q = self.fc1(x)
  40. k = self.fc2(x)
  41. v = self.fc3(x)
  42. k = self.transpose(k, (1, 0))
  43. c = self.relu4(self.matmul1(q, k))
  44. s = self.relu5(self.matmul2(c, v))
  45. s = self.fc4(s)
  46. return s
  47. class MulNegTwoOutputNet(nn.Cell):
  48. def __init__(self):
  49. super().__init__()
  50. self.mul = P.Mul().shard(((2, 4), (2, 4)))
  51. self.neg = P.Neg().shard(((2, 4),))
  52. self.mul_weight = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight")
  53. def construct(self, x):
  54. out1 = self.mul(x, self.mul_weight)
  55. out2 = self.neg(out1)
  56. return out1, out2
  57. class ReshapeMatMulNet(nn.Cell):
  58. def __init__(self, strategy1, strategy2):
  59. super().__init__()
  60. self.reshape = P.Reshape()
  61. self.matmul = P.MatMul().shard(strategy2)
  62. self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
  63. # x (64, 4, 7)
  64. def construct(self, x):
  65. out = self.reshape(x, (64, 28))
  66. out = self.matmul(out, self.matmul_weight)
  67. return out
  68. class MatMulReshapeNet(nn.Cell):
  69. def __init__(self, strategy1, strategy2):
  70. super().__init__()
  71. self.reshape = P.Reshape()
  72. self.matmul = P.MatMul().shard(strategy1)
  73. self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")
  74. # x (128, 28)
  75. def construct(self, x):
  76. out = self.matmul(x, self.matmul_weight)
  77. out = self.reshape(out, (64, -1))
  78. return out
  79. class ReshapeMulNet(nn.Cell):
  80. def __init__(self):
  81. super().__init__()
  82. self.reshape = P.Reshape()
  83. self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
  84. self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
  85. def construct(self, x):
  86. weight = self.reshape(self.mul_weight, (1, 128, 96))
  87. out = self.mul(weight, self.mul_weight)
  88. return out
  89. def compile_graph(x, net):
  90. net.set_auto_parallel()
  91. net.set_train(False)
  92. _executor.compile(net, x, auto_parallel_mode=True)
  93. strategies = _executor._get_shard_strategy(net)
  94. return strategies
  95. def test_dense_relu_semi_auto():
  96. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
  97. net = DenseMutMulNet()
  98. x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
  99. strategies = compile_graph(x, net)
  100. for (k, v) in strategies.items():
  101. if re.search('VirtualOutput-op', k) is not None:
  102. assert v[0][0] == 8
  103. def test_dense_relu_semi_auto_full_batch():
  104. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
  105. net = DenseMutMulNet()
  106. x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
  107. strategies = compile_graph(x, net)
  108. for (k, v) in strategies.items():
  109. if re.search('VirtualOutput-op', k) is not None:
  110. assert v[0][0] == 1
  111. def test_dense_relu_auto():
  112. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
  113. net = DenseMutMulNet()
  114. x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
  115. strategies = compile_graph(x, net)
  116. for (k, v) in strategies.items():
  117. if re.search('VirtualOutput-op', k) is not None:
  118. assert v[0][0] == 8
  119. def test_dense_relu_auto_full_batch():
  120. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
  121. net = DenseMutMulNet()
  122. x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
  123. strategies = compile_graph(x, net)
  124. for (k, v) in strategies.items():
  125. if re.search('VirtualOutput-op', k) is not None:
  126. assert v[0][0] == 1
  127. def test_mul_neg_two_output_semi_auto():
  128. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
  129. net = MulNegTwoOutputNet()
  130. x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
  131. strategies = compile_graph(x, net)
  132. count = 0
  133. for (k, v) in strategies.items():
  134. if re.search('VirtualOutput-op', k) is not None:
  135. count += 1
  136. assert v[0][0] == 8
  137. assert count == 2
  138. def test_mul_neg_two_output_semi_auto_full_batch():
  139. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
  140. net = MulNegTwoOutputNet()
  141. x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
  142. strategies = compile_graph(x, net)
  143. count = 0
  144. for (k, v) in strategies.items():
  145. if re.search('VirtualOutput-op', k) is not None:
  146. count += 1
  147. assert v[0][0] == 1
  148. assert count == 2
  149. def test_mul_neg_two_output_auto():
  150. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
  151. net = MulNegTwoOutputNet()
  152. x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
  153. strategies = compile_graph(x, net)
  154. count = 0
  155. for (k, v) in strategies.items():
  156. if re.search('VirtualOutput-op', k) is not None:
  157. count += 1
  158. assert v[0][0] == 8
  159. assert count == 2
  160. def test_mul_neg_two_output_full_batch():
  161. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
  162. net = MulNegTwoOutputNet()
  163. x = Tensor(np.ones([32, 128]).astype(np.float32) * 0.01)
  164. strategies = compile_graph(x, net)
  165. count = 0
  166. for (k, v) in strategies.items():
  167. if re.search('VirtualOutput-op', k) is not None:
  168. count += 1
  169. assert v[0][0] == 1
  170. assert count == 2
  171. def test_reshape_matmul_semi_auto():
  172. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
  173. strategy1 = None
  174. strategy2 = ((1, 1), (1, 8))
  175. net = ReshapeMatMulNet(strategy1, strategy2)
  176. x = Tensor(np.ones([64, 4, 7]), ms.float32)
  177. strategies = compile_graph(x, net)
  178. for (k, v) in strategies.items():
  179. if re.search('VirtualOutput-op', k) is not None:
  180. assert v[0][0] == 8
  181. def test_reshape_matmul_auto():
  182. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
  183. strategy1 = None
  184. strategy2 = ((1, 1), (1, 8))
  185. net = ReshapeMatMulNet(strategy1, strategy2)
  186. x = Tensor(np.ones([64, 4, 7]), ms.float32)
  187. strategies = compile_graph(x, net)
  188. for (k, v) in strategies.items():
  189. if re.search('VirtualOutput-op', k) is not None:
  190. assert v[0][0] == 8
  191. def test_matmul_reshape_semi_auto():
  192. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=False)
  193. strategy2 = None
  194. strategy1 = ((1, 1), (1, 8))
  195. net = MatMulReshapeNet(strategy1, strategy2)
  196. x = Tensor(np.ones([128, 28]), ms.float32)
  197. strategies = compile_graph(x, net)
  198. for (k, v) in strategies.items():
  199. if re.search('VirtualOutput-op', k) is not None:
  200. assert v[0][0] == 8
  201. def test_matmul_reshape_auto():
  202. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=False)
  203. strategy2 = None
  204. strategy1 = ((1, 1), (1, 8))
  205. net = MatMulReshapeNet(strategy1, strategy2)
  206. x = Tensor(np.ones([128, 28]), ms.float32)
  207. strategies = compile_graph(x, net)
  208. for (k, v) in strategies.items():
  209. if re.search('VirtualOutput-op', k) is not None:
  210. assert v[0][0] == 8
  211. def test_reshape_mul_semi_auto():
  212. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel", full_batch=True)
  213. net = ReshapeMulNet()
  214. x = Tensor(np.ones([64, 4]), ms.float32)
  215. strategies = compile_graph(x, net)
  216. for (k, v) in strategies.items():
  217. if re.search('VirtualOutput-op', k) is not None:
  218. assert v[0][0] == 1
  219. def test_reshape_mul_auto():
  220. context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="auto_parallel", full_batch=True)
  221. net = ReshapeMulNet()
  222. x = Tensor(np.ones([64, 4]), ms.float32)
  223. strategies = compile_graph(x, net)
  224. for (k, v) in strategies.items():
  225. if re.search('VirtualOutput-op', k) is not None:
  226. assert v[0][0] == 1