You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_comm_fusion.py 9.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import pytest
  17. import mindspore as ms
  18. import mindspore.nn as nn
  19. from mindspore import context
  20. from mindspore import Tensor
  21. from mindspore.ops import operations as P
  22. from mindspore.common.parameter import Parameter
  23. from mindspore.common.initializer import initializer
  24. from mindspore.train.model import Model
  25. from mindspore.nn.wrap.cell_wrapper import PipelineCell
  26. from mindspore.parallel._auto_parallel_context import auto_parallel_context
  27. from tests.ut.python.parallel.test_adafactor import compile_net
  28. from tests.ut.python.parallel.test_adafactor import Net as Net2
  class DatasetLenet:
      """In-memory stand-in for a dataset, used with ``Model.train`` in these tests.

      Iterating yields the same ``(data, label)`` pair ``length`` times per pass;
      ``reset()`` rewinds it for another pass. The size/repeat/batch getters return
      fixed values that are independent of ``length``.
      """

      def __init__(self, data, label, length=3):
          self.data = data
          self.label = label
          # Start at 0, the same position reset() rewinds to, so every pass -- including
          # the first -- yields exactly ``length`` items. (Starting at 1 made the first
          # pass one item short of every later pass.)
          self.index = 0
          self.length = length

      def __iter__(self):
          return self

      def __next__(self):
          if self.index >= self.length:
              raise StopIteration
          self.index += 1
          return self.data, self.label

      def reset(self):
          """Rewind so the next pass yields ``length`` items again."""
          self.index = 0

      @staticmethod
      def get_dataset_size():
          # Fixed value; not derived from ``length``.
          return 32

      @staticmethod
      def get_repeat_count():
          return 1

      @staticmethod
      def get_batch_size():
          return 32

      def create_tuple_iterator(self, num_epochs=1, do_copy=True):
          """Return ``self`` as the iterator.

          ``num_epochs`` and ``do_copy`` are ignored; they are accepted only so the
          call signature matches what the training loop passes (presumably MindSpore's
          ``Dataset.create_tuple_iterator`` -- confirm).
          """
          return self
  class MatMulCell(nn.Cell):
      """Two chained MatMuls against 64x64 weights, each with its own shard strategy.

      Args:
          strategy1: shard strategy for the first MatMul (``x`` times ``param``).
          strategy2: shard strategy for the second MatMul (result times ``param1``).
          param: optional Parameter used as the first weight instead of a fresh
              zero-initialized one (``Net`` passes the same object to every block).
      """

      def __init__(self, strategy1, strategy2, param=None):
          super().__init__()
          # Build the default weight only when no replacement is supplied, rather than
          # creating a Parameter and immediately overwriting it.
          self.param = param if param is not None else Parameter(initializer("zeros", [64, 64]), name="param")
          self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
          self.matmul = P.MatMul().shard(strategy1)
          self.matmul1 = P.MatMul().shard(strategy2)

      def construct(self, x):
          out = self.matmul(x, self.param)
          out = self.matmul1(out, self.param1)
          return out
  class Net(nn.Cell):
      """Chain of two ``MatMulCell`` blocks; block ``i`` is assigned to pipeline stage ``i``.

      ``strategy1``, ``strategy2`` and ``param`` are forwarded unchanged to every block.
      """

      def __init__(self, strategy1, strategy2, param=None):
          super().__init__()
          self.block = nn.CellList()
          for stage in range(2):
              sub_cell = MatMulCell(strategy1, strategy2, param)
              sub_cell.pipeline_stage = stage
              self.block.append(sub_cell)

      def construct(self, x):
          out = x
          for idx in range(2):
              out = self.block[idx](out)
          return out
  class PipelineSplit(nn.Cell):
      """Wraps ``Net`` so it can be driven by ``PipelineCell`` with ``(x, label)`` inputs.

      ``label`` is accepted but not used by ``construct``.
      """

      def __init__(self, strategy1, strategy2):
          super().__init__()
          self.cell = Net(strategy1, strategy2)
          # NOTE(review): tags the first block's first MatMul with "parameter_start"=0;
          # its exact meaning is defined by MindSpore's pipeline handling -- confirm there.
          first_matmul = self.cell.block[0].matmul
          first_matmul.add_prim_attr("parameter_start", 0)

      def construct(self, x, label):
          return self.cell(x)
  def test_fusion_size():
      """
      Feature: test_fusion_size in size mode
      Description: allgather and reduce scatter fusion in size mode
      Expectation: success
      """
      # Start from a clean auto-parallel context so settings left by other tests cannot leak in.
      context.reset_auto_parallel_context()
      try:
          allgather_threshold = 8
          reducescatter_threshold = 16
          comm_fusion_dict = {"allgather": {"mode": "size", "config": allgather_threshold},
                              "reducescatter": {"mode": "size", "config": reducescatter_threshold}}
          context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
          context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2,
                                            enable_parallel_optimizer=True)
          data = Tensor(np.ones([32, 64]), dtype=ms.float32)
          label = Tensor(np.ones([64]), dtype=ms.float32)
          strategy1 = ((4, 1), (1, 1))
          strategy2 = ((2, 1), (1, 1))
          net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
          dataset = DatasetLenet(data, label, 3)
          optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
          model = Model(net, optimizer=optimizer)
          model.train(2, dataset, dataset_sink_mode=False)
          # In "size" mode the configured values should be reported back as the thresholds.
          assert auto_parallel_context().allgather_fusion_threshold_mb() == allgather_threshold
          assert auto_parallel_context().reducescatter_fusion_threshold_mb() == reducescatter_threshold
      finally:
          # Do not leak this test's parallel configuration into later tests.
          context.reset_auto_parallel_context()
  def test_fusion_auto():
      """
      Feature: test_fusion_auto in auto mode
      Description: allgather and reduce scatter fusion in auto mode
      Expectation: success
      """
      # Start from a clean auto-parallel context so settings left by other tests cannot leak in.
      context.reset_auto_parallel_context()
      try:
          comm_fusion_dict = {"allgather": {"mode": "auto", "config": None},
                              "reducescatter": {"mode": "auto", "config": None}}
          context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2,
                                            enable_parallel_optimizer=True, parallel_mode="semi_auto_parallel",
                                            comm_fusion=comm_fusion_dict)
          data = Tensor(np.ones([32, 64]), dtype=ms.float32)
          label = Tensor(np.ones([64]), dtype=ms.float32)
          strategy1 = ((4, 1), (1, 1))
          strategy2 = ((2, 1), (1, 1))
          net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
          dataset = DatasetLenet(data, label, 3)
          optimizer = nn.Lamb(net.trainable_params(), learning_rate=0.01)
          model = Model(net, optimizer=optimizer)
          model.train(2, dataset, dataset_sink_mode=False)
          # "auto" mode is expected to report a 64 MB threshold for both collectives.
          assert auto_parallel_context().allgather_fusion_threshold_mb() == 64
          assert auto_parallel_context().reducescatter_fusion_threshold_mb() == 64
      finally:
          # Do not leak this test's parallel configuration into later tests.
          context.reset_auto_parallel_context()
  def test_fusion_optimizer_parallel():
      """
      Feature: test_fusion_optimizer_parallel in size mode
      Description: allgather and reduce scatter fusion with optimizer parallel, first in size mode,
          then recompiling the same net in auto mode
      Expectation: compile success
      """
      # Start from a clean auto-parallel context: earlier tests in this file set pipeline_stages=2,
      # which this test never overrides.
      context.reset_auto_parallel_context()
      try:
          allgather_threshold = 16
          reducescatter_threshold = 8
          comm_fusion_dict = {"allgather": {"mode": "size", "config": allgather_threshold},
                              "reducescatter": {"mode": "size", "config": reducescatter_threshold}}
          context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
                                            enable_parallel_optimizer=True, comm_fusion=comm_fusion_dict)
          _w0 = Tensor(np.ones([64, 16, 2]), dtype=ms.float32)
          _w1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
          _w2 = Tensor(np.ones([32]), dtype=ms.float32)
          strategy1 = ((4, 2), (2, 2))
          strategy2 = ((4, 2), (2,))
          net = Net2(_w0, _w1, _w2, strategy1, strategy2)
          compile_net(net)
          # The full context is configured again before the second compile.
          # NOTE(review): presumably because compile_net (from test_adafactor) resets it -- confirm.
          comm_fusion_dict = {"allgather": {"mode": "auto", "config": None},
                              "reducescatter": {"mode": "auto", "config": None}}
          context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0,
                                            enable_parallel_optimizer=True, comm_fusion=comm_fusion_dict)
          compile_net(net)
      finally:
          # Do not leak this test's parallel configuration into later tests.
          context.reset_auto_parallel_context()
  def test_allgather_fusion_invalid_value_failed():
      """
      Feature: test_allgather_fusion with invalid value
      Description: set_auto_parallel_context(comm_fusion=...) rejects malformed allgather fusion configs
      Expectation: throw TypeError for values of the wrong type, KeyError for an unknown key,
          an unknown mode, or a missing "config" entry
      """
      invalid_cases = [
          (TypeError, [1, 2]),
          (TypeError, {"allgather": [1, 2]}),
          (TypeError, {"allgather": {"mode": "size", "config": "30.12"}}),
          (KeyError, {"all": {"mode": "size", "config": 30}}),
          (KeyError, {"allgather": {"modes": "size", "config": 30}}),
          (KeyError, {"allgather": {"mode": "sizes", "config": 30}}),
          (KeyError, {"allgather": {"mode": "size"}}),
      ]
      try:
          for expected_error, comm_fusion_dict in invalid_cases:
              # Only the call under test sits inside pytest.raises.
              with pytest.raises(expected_error):
                  context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
      finally:
          # A rejected call may already have applied parallel_mode; don't leak it into later tests.
          context.reset_auto_parallel_context()
  def test_reducescatter_fusion_invalid_value_failed():
      """
      Feature: test_reducescatter_fusion with invalid value
      Description: set_auto_parallel_context(comm_fusion=...) rejects malformed reducescatter fusion configs
      Expectation: throw TypeError for values of the wrong type, KeyError for an unknown key,
          an unknown mode, or a missing "config" entry
      """
      invalid_cases = [
          (TypeError, {"reducescatter": [1, 2]}),
          (TypeError, {"reducescatter": {"mode": "size", "config": "30.12"}}),
          (KeyError, {"reducescatter": {"modes": "size", "config": 30}}),
          (KeyError, {"reducescatter": {"mode": "sizes", "config": 30}}),
          (KeyError, {"reducescatter": {"mode": "size"}}),
      ]
      try:
          for expected_error, comm_fusion_dict in invalid_cases:
              # Only the call under test sits inside pytest.raises.
              with pytest.raises(expected_error):
                  context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", comm_fusion=comm_fusion_dict)
      finally:
          # A rejected call may already have applied parallel_mode; don't leak it into later tests.
          context.reset_auto_parallel_context()