You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_pipeline_split.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore as ms
  17. import mindspore.nn as nn
  18. from mindspore import context
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.common.parameter import Parameter
  22. from mindspore.common.initializer import initializer
  23. from mindspore.train.model import Model
  24. class DatasetLenet():
  25. def __init__(self, data, label, length=3):
  26. self.data = data
  27. self.label = label
  28. self.index = 1
  29. self.length = length
  30. def __iter__(self):
  31. return self
  32. def __next__(self):
  33. if self.index >= self.length:
  34. raise StopIteration
  35. self.index += 1
  36. return self.data, self.label
  37. def reset(self):
  38. self.index = 0
  39. def get_dataset_size(self):
  40. return 32
  41. def get_repeat_count(self):
  42. return 1
  43. def get_batch_size(self):
  44. return 32
  45. def create_tuple_iterator(self, num_epochs=1, do_copy=True):
  46. return self
  47. class MatMulCell(nn.Cell):
  48. def __init__(self, strategy1, strategy2, param=None):
  49. super().__init__()
  50. self.param = Parameter(initializer("zeros", [64, 64]), name="param")
  51. if param is not None:
  52. self.param = param
  53. self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
  54. self.matmul = P.MatMul().shard(strategy1)
  55. self.matmul1 = P.MatMul().shard(strategy2)
  56. def construct(self, x):
  57. out = self.matmul(x, self.param)
  58. out = self.matmul1(out, self.param1)
  59. return out
  60. class Net(nn.Cell):
  61. def __init__(self, strategy1, strategy2, param=None):
  62. super().__init__()
  63. self.block = nn.CellList()
  64. for i in range(2):
  65. cell = MatMulCell(strategy1, strategy2, param)
  66. cell.pipeline_stage = i
  67. self.block.append(cell)
  68. def construct(self, x):
  69. for i in range(2):
  70. x = self.block[i](x)
  71. return x
  72. class PipelineSplit(nn.Cell):
  73. def __init__(self, strategy1, strategy2):
  74. super().__init__()
  75. self.cell = Net(strategy1, strategy2)
  76. def construct(self, x, label):
  77. x = self.cell(x)
  78. return x
  79. class PipelineSplit2(nn.Cell):
  80. def __init__(self, strategy1, strategy2):
  81. super().__init__()
  82. self.param = Parameter(initializer("zeros", [64, 64]), name="param")
  83. self.cell = Net(strategy1, strategy2, self.param)
  84. def construct(self, x, label):
  85. x = self.cell(x)
  86. return x
  87. def test_pipeline_split_stage0():
  88. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
  89. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  90. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  91. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  92. strategy1 = ((4, 1), (1, 1))
  93. strategy2 = ((2, 1), (1, 1))
  94. net = PipelineSplit(strategy1, strategy2)
  95. params = net.cell.block[0].trainable_params()
  96. dataset = DatasetLenet(data, label, 3)
  97. optimizer = nn.Lamb(params, learning_rate=0.01)
  98. model = Model(net, optimizer=optimizer)
  99. model.train(2, dataset, dataset_sink_mode=False)
  100. for _, param in model._train_network.parameters_and_names():
  101. assert param.name != "cell.block.1.param"
  102. assert param.name != "cell.block.1.param1"
  103. def test_pipeline_split_stage1():
  104. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
  105. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  106. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  107. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  108. strategy1 = ((4, 1), (1, 1))
  109. strategy2 = ((2, 1), (1, 1))
  110. net = PipelineSplit(strategy1, strategy2)
  111. params = net.cell.block[1].trainable_params()
  112. dataset = DatasetLenet(data, label, 3)
  113. optimizer = nn.Lamb(params, learning_rate=0.01)
  114. model = Model(net, optimizer=optimizer)
  115. model.train(2, dataset, dataset_sink_mode=False)
  116. for _, param in model._train_network.parameters_and_names():
  117. assert param.name != "cell.block.0.param"
  118. assert param.name != "cell.block.0.param1"
  119. def test_pipeline_split_shared_parameter_stage0():
  120. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
  121. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  122. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  123. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  124. strategy1 = ((4, 1), (1, 1))
  125. strategy2 = ((2, 1), (1, 1))
  126. net = PipelineSplit2(strategy1, strategy2)
  127. params = net.cell.block[0].trainable_params()
  128. dataset = DatasetLenet(data, label, 3)
  129. optimizer = nn.Lamb(params, learning_rate=0.01)
  130. model = Model(net, optimizer=optimizer)
  131. model.train(2, dataset, dataset_sink_mode=False)
  132. def test_pipeline_split_shared_parameter_stage1():
  133. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
  134. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  135. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  136. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  137. strategy1 = ((4, 1), (1, 1))
  138. strategy2 = ((2, 1), (1, 1))
  139. net = PipelineSplit2(strategy1, strategy2)
  140. params = net.cell.block[1].trainable_params()
  141. dataset = DatasetLenet(data, label, 3)
  142. optimizer = nn.Lamb(params, learning_rate=0.01)
  143. model = Model(net, optimizer=optimizer)
  144. model.train(2, dataset, dataset_sink_mode=False)