You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_pipeline_split.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import numpy as np
  16. import mindspore as ms
  17. import mindspore.nn as nn
  18. from mindspore import context
  19. from mindspore import Tensor
  20. from mindspore.ops import operations as P
  21. from mindspore.common.parameter import Parameter
  22. from mindspore.common.initializer import initializer
  23. from mindspore.train.model import Model
  24. from mindspore.nn.wrap.cell_wrapper import PipelineCell
  25. class DatasetLenet():
  26. def __init__(self, data, label, length=3):
  27. self.data = data
  28. self.label = label
  29. self.index = 1
  30. self.length = length
  31. def __iter__(self):
  32. return self
  33. def __next__(self):
  34. if self.index >= self.length:
  35. raise StopIteration
  36. self.index += 1
  37. return self.data, self.label
  38. def reset(self):
  39. self.index = 0
  40. def get_dataset_size(self):
  41. return 32
  42. def get_repeat_count(self):
  43. return 1
  44. def get_batch_size(self):
  45. return 32
  46. def create_tuple_iterator(self, num_epochs=1, do_copy=True):
  47. return self
  48. class MatMulCell(nn.Cell):
  49. def __init__(self, strategy1, strategy2, param=None):
  50. super().__init__()
  51. self.param = Parameter(initializer("zeros", [64, 64]), name="param")
  52. if param is not None:
  53. self.param = param
  54. self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
  55. self.matmul = P.MatMul().shard(strategy1)
  56. self.matmul1 = P.MatMul().shard(strategy2)
  57. def construct(self, x):
  58. out = self.matmul(x, self.param)
  59. out = self.matmul1(out, self.param1)
  60. return out
  61. class Net(nn.Cell):
  62. def __init__(self, strategy1, strategy2, param=None):
  63. super().__init__()
  64. self.block = nn.CellList()
  65. for i in range(2):
  66. cell = MatMulCell(strategy1, strategy2, param)
  67. cell.pipeline_stage = i
  68. self.block.append(cell)
  69. def construct(self, x):
  70. for i in range(2):
  71. x = self.block[i](x)
  72. return x
  73. class PipelineSplit(nn.Cell):
  74. def __init__(self, strategy1, strategy2):
  75. super().__init__()
  76. self.cell = Net(strategy1, strategy2)
  77. self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)
  78. def construct(self, x, label):
  79. x = self.cell(x)
  80. return x
  81. class PipelineSplit2(nn.Cell):
  82. def __init__(self, strategy1, strategy2):
  83. super().__init__()
  84. self.param = Parameter(initializer("zeros", [64, 64]), name="param")
  85. self.cell = Net(strategy1, strategy2, self.param)
  86. self.cell.block[0].matmul.add_prim_attr("parameter_start", 0)
  87. def construct(self, x, label):
  88. x = self.cell(x)
  89. return x
  90. def test_pipeline_split_stage0():
  91. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
  92. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  93. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  94. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  95. strategy1 = ((4, 1), (1, 1))
  96. strategy2 = ((2, 1), (1, 1))
  97. net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
  98. params = net.network.cell.block[0].trainable_params()
  99. dataset = DatasetLenet(data, label, 3)
  100. optimizer = nn.Lamb(params, learning_rate=0.01)
  101. model = Model(net, optimizer=optimizer)
  102. model.train(2, dataset, dataset_sink_mode=False)
  103. for _, param in model._train_network.parameters_and_names():
  104. assert param.name != "cell.block.1.param"
  105. assert param.name != "cell.block.1.param1"
  106. def test_pipeline_split_stage1():
  107. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
  108. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  109. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  110. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  111. strategy1 = ((4, 1), (1, 1))
  112. strategy2 = ((2, 1), (1, 1))
  113. net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
  114. params = net.network.cell.block[1].trainable_params()
  115. dataset = DatasetLenet(data, label, 3)
  116. optimizer = nn.Lamb(params, learning_rate=0.01)
  117. model = Model(net, optimizer=optimizer)
  118. model.train(2, dataset, dataset_sink_mode=False)
  119. for _, param in model._train_network.parameters_and_names():
  120. assert param.name != "cell.block.0.param"
  121. assert param.name != "cell.block.0.param1"
  122. def test_pipeline_split_shared_parameter_stage0():
  123. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2)
  124. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  125. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  126. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  127. strategy1 = ((4, 1), (1, 1))
  128. strategy2 = ((2, 1), (1, 1))
  129. net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
  130. params = net.network.cell.block[0].trainable_params()
  131. dataset = DatasetLenet(data, label, 3)
  132. optimizer = nn.Lamb(params, learning_rate=0.01)
  133. model = Model(net, optimizer=optimizer)
  134. model.train(2, dataset, dataset_sink_mode=False)
  135. def test_pipeline_split_shared_parameter_stage1():
  136. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2)
  137. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  138. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  139. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  140. strategy1 = ((4, 1), (1, 1))
  141. strategy2 = ((2, 1), (1, 1))
  142. net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
  143. params = net.network.cell.block[1].trainable_params()
  144. dataset = DatasetLenet(data, label, 3)
  145. optimizer = nn.Lamb(params, learning_rate=0.01)
  146. model = Model(net, optimizer=optimizer)
  147. model.train(2, dataset, dataset_sink_mode=False)
  148. def test_pipeline_split_shared_parameter_stage0_predict():
  149. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, full_batch=True)
  150. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  151. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  152. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  153. strategy1 = ((4, 1), (1, 1))
  154. strategy2 = ((2, 1), (1, 1))
  155. net = PipelineSplit2(strategy1, strategy2)
  156. model = Model(net)
  157. model.predict(data, label)
  158. def test_pipeline_split_shared_parameter_stage1_predict():
  159. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, full_batch=True)
  160. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  161. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  162. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  163. strategy1 = ((4, 1), (1, 1))
  164. strategy2 = ((2, 1), (1, 1))
  165. net = PipelineSplit2(strategy1, strategy2)
  166. model = Model(net)
  167. model.predict(data, label)
  168. def test_pipeline_split_stage0_opt_shard():
  169. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
  170. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  171. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  172. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  173. strategy1 = ((4, 1), (1, 1))
  174. strategy2 = ((2, 1), (1, 1))
  175. net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
  176. params = net.network.cell.block[0].trainable_params()
  177. dataset = DatasetLenet(data, label, 3)
  178. optimizer = nn.Lamb(params, learning_rate=0.01)
  179. model = Model(net, optimizer=optimizer)
  180. model.train(2, dataset, dataset_sink_mode=False)
  181. for _, param in model._train_network.parameters_and_names():
  182. assert param.name != "cell.block.1.param"
  183. assert param.name != "cell.block.1.param1"
  184. def test_pipeline_split_stage1_opt_shard():
  185. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
  186. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  187. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  188. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  189. strategy1 = ((4, 1), (1, 1))
  190. strategy2 = ((2, 1), (1, 1))
  191. net = PipelineCell(PipelineSplit(strategy1, strategy2), 4)
  192. params = net.network.cell.block[1].trainable_params()
  193. dataset = DatasetLenet(data, label, 3)
  194. optimizer = nn.Lamb(params, learning_rate=0.01)
  195. model = Model(net, optimizer=optimizer)
  196. model.train(2, dataset, dataset_sink_mode=False)
  197. for _, param in model._train_network.parameters_and_names():
  198. assert param.name != "cell.block.0.param"
  199. assert param.name != "cell.block.0.param1"
  200. def test_pipeline_split_shared_parameter_stage0_opt_shard():
  201. context.set_auto_parallel_context(device_num=8, global_rank=0, pipeline_stages=2, enable_parallel_optimizer=True)
  202. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  203. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  204. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  205. strategy1 = ((4, 1), (1, 1))
  206. strategy2 = ((2, 1), (1, 1))
  207. net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
  208. params = net.network.cell.block[0].trainable_params()
  209. dataset = DatasetLenet(data, label, 3)
  210. optimizer = nn.Lamb(params, learning_rate=0.01)
  211. model = Model(net, optimizer=optimizer)
  212. model.train(2, dataset, dataset_sink_mode=False)
  213. def test_pipeline_split_shared_parameter_stage1_opt_shard():
  214. context.set_auto_parallel_context(device_num=8, global_rank=4, pipeline_stages=2, enable_parallel_optimizer=True)
  215. context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
  216. data = Tensor(np.ones([32, 64]), dtype=ms.float32)
  217. label = Tensor(np.ones([64, 64]), dtype=ms.float32)
  218. strategy1 = ((4, 1), (1, 1))
  219. strategy2 = ((2, 1), (1, 1))
  220. net = PipelineCell(PipelineSplit2(strategy1, strategy2), 4)
  221. params = net.network.cell.block[1].trainable_params()
  222. dataset = DatasetLenet(data, label, 3)
  223. optimizer = nn.Lamb(params, learning_rate=0.01)
  224. model = Model(net, optimizer=optimizer)
  225. model.train(2, dataset, dataset_sink_mode=False)