You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_pipeline_split.py 5.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore as ms
  17. import mindspore.nn as nn
  18. from mindspore import context
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.common.parameter import Parameter
  22. from mindspore.common.initializer import initializer
  23. from mindspore.train.model import Model
  24. class DatasetLenet():
  25. def __init__(self, data, label, length=3):
  26. self.data = data
  27. self.label = label
  28. self.index = 1
  29. self.length = length
  30. def __iter__(self):
  31. return self
  32. def __next__(self):
  33. if self.index >= self.length:
  34. raise StopIteration
  35. self.index += 1
  36. return self.data, self.label
  37. def reset(self):
  38. self.index = 0
  39. def get_dataset_size(self):
  40. return 32
  41. def get_repeat_count(self):
  42. return 1
  43. def get_batch_size(self):
  44. return 32
  45. def create_tuple_iterator(self, num_epochs=1, do_copy=True):
  46. return self
  47. class MatMulCell(nn.Cell):
  48. def __init__(self, strategy1, strategy2, param=None):
  49. super().__init__()
  50. self.param = Parameter(initializer("zeros", [64, 64]), name="param")
  51. if param is not None:
  52. self.param = param
  53. self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
  54. self.matmul = P.MatMul().shard(strategy1)
  55. self.matmul1 = P.MatMul().shard(strategy2)
  56. def construct(self, x):
  57. out = self.matmul(x, self.param)
  58. out = self.matmul1(out, self.param1)
  59. return out
  60. class Net(nn.Cell):
  61. def __init__(self, strategy1, strategy2, param=None):
  62. super().__init__()
  63. self.block = nn.CellList()
  64. for i in range(2):
  65. cell = MatMulCell(strategy1, strategy2, param)
  66. cell.stage = i
  67. self.block.append(cell)
  68. def construct(self, x):
  69. for i in range(2):
  70. x = self.block[i](x)
  71. return x
  72. class PipelineSplit(nn.Cell):
  73. def __init__(self, strategy1, strategy2):
  74. super().__init__()
  75. self.cell = Net(strategy1, strategy2)
  76. def construct(self, x, label):
  77. x = self.cell(x)
  78. return x
  79. class PipelineSplit2(nn.Cell):
  80. def __init__(self, strategy1, strategy2):
  81. super().__init__()
  82. self.param = Parameter(initializer("zeros", [64, 64]), name="param")
  83. self.cell = Net(strategy1, strategy2, self.param)
  84. def construct(self, x, label):
  85. x = self.cell(x)
  86. return x
  87. def test_pipeline_split_stage0():
  88. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
  89. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  90. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  91. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  92. strategy1 = ((4, 1), (1, 1))
  93. strategy2 = ((2, 1), (1, 1))
  94. net = PipelineSplit(strategy1, strategy2)
  95. params = net.cell.block[0].trainable_params()
  96. dataset = DatasetLenet(data, label, 3)
  97. optimizer = nn.Lamb(params, learning_rate=0.01)
  98. model = Model(net, optimizer=optimizer)
  99. model.train(2, dataset, dataset_sink_mode=False)
  100. def test_pipeline_split_stage1():
  101. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
  102. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  103. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  104. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  105. strategy1 = ((4, 1), (1, 1))
  106. strategy2 = ((2, 1), (1, 1))
  107. net = PipelineSplit(strategy1, strategy2)
  108. params = net.cell.block[1].trainable_params()
  109. dataset = DatasetLenet(data, label, 3)
  110. optimizer = nn.Lamb(params, learning_rate=0.01)
  111. model = Model(net, optimizer=optimizer)
  112. model.train(2, dataset, dataset_sink_mode=False)
  113. def test_pipeline_split_shared_parameter_stage0():
  114. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
  115. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  116. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  117. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  118. strategy1 = ((4, 1), (1, 1))
  119. strategy2 = ((2, 1), (1, 1))
  120. net = PipelineSplit2(strategy1, strategy2)
  121. params = net.cell.block[0].trainable_params()
  122. dataset = DatasetLenet(data, label, 3)
  123. optimizer = nn.Lamb(params, learning_rate=0.01)
  124. model = Model(net, optimizer=optimizer)
  125. model.train(2, dataset, dataset_sink_mode=False)
  126. def test_pipeline_split_shared_parameter_stage1():
  127. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
  128. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  129. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  130. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  131. strategy1 = ((4, 1), (1, 1))
  132. strategy2 = ((2, 1), (1, 1))
  133. net = PipelineSplit2(strategy1, strategy2)
  134. params = net.cell.block[1].trainable_params()
  135. dataset = DatasetLenet(data, label, 3)
  136. optimizer = nn.Lamb(params, learning_rate=0.01)
  137. model = Model(net, optimizer=optimizer)
  138. model.train(2, dataset, dataset_sink_mode=False)