You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_parallel_transformer.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. import mindspore.common.dtype as mstype
  17. import mindspore.nn as nn
  18. from mindspore import Tensor
  19. from mindspore.context import set_auto_parallel_context, ParallelMode
  20. from mindspore.ops import composite as C
  21. from mindspore.ops import functional as F
  22. import mindspore.ops as P
  23. from mindspore.parallel.nn import TransformerEncoder, TransformerDecoder, Transformer, TransformerOpParallelConfig, \
  24. VocabEmbedding, CrossEntropyLoss, OpParallelConfig, EmbeddingOpParallelConfig
  25. from mindspore.nn import Dense as Linear
  26. from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
  27. from mindspore.nn.optim import AdamWeightDecay
  28. from mindspore.nn.wrap.cell_wrapper import PipelineCell, _VirtualDatasetCell, TrainOneStepCell
  29. from mindspore.nn.wrap.loss_scale import _TrainPipelineWithLossScaleCell
  30. from mindspore.train import Model
  31. from tests.dataset_mock import MindData
  32. from tests.ut.python.ops.test_math_ops import VirtualLoss
# Module-level gradient operator constructed with ``get_all=True``.
# NOTE(review): no test in this file references ``grad_all`` — confirm it is
# still needed (e.g. by other modules) before removing it.
grad_all = C.GradOperation(get_all=True)
  34. class Dataset(MindData):
  35. def __init__(self, *inputs, length=3):
  36. super(Dataset, self).__init__(size=length)
  37. self.inputs = inputs
  38. self.index = 0
  39. self.length = length
  40. def __iter__(self):
  41. return self
  42. def __next__(self):
  43. if self.index >= self.length:
  44. raise StopIteration
  45. self.index += 1
  46. return self.inputs
  47. def reset(self):
  48. self.index = 0
# Shared op-parallel configuration used by most tests below:
# data_parallel=1, model_parallel=8, vocab_emb_dp=False.
config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
# Same sharding as ``config`` plus 4 pipeline stages and 4 micro-batches; the
# pipeline tests pair it with a 32-device auto-parallel context.
pipeline_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, pipeline_stage=4,
                                              micro_batch_num=4, vocab_emb_dp=False)
  52. class NetWithLossFiveInputs(nn.Cell):
  53. def __init__(self, network):
  54. super(NetWithLossFiveInputs, self).__init__()
  55. self.loss = VirtualLoss()
  56. self.network = network
  57. def construct(self, x1, x2, x3, x4, x5):
  58. predict, _, _ = self.network(x1, x2, x3, x4, x5)
  59. return self.loss(predict)
  60. def run_total_transformer_model_head(e_layer,
  61. d_layer,
  62. arg_parallel_config):
  63. dp = arg_parallel_config.data_parallel
  64. mp = arg_parallel_config.model_parallel
  65. pp = arg_parallel_config.pipeline_stage
  66. if dp * mp * pp != 1:
  67. set_auto_parallel_context(device_num=8,
  68. full_batch=True,
  69. global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  70. class Net(nn.Cell):
  71. def __init__(self, en_layer, de_layer, parallel_config):
  72. super(Net, self).__init__()
  73. self.embedding = VocabEmbedding(vocab_size=240, embedding_size=20,
  74. parallel_config=config.embedding_dp_mp_config)
  75. self.network = Transformer(encoder_layers=en_layer,
  76. decoder_layers=de_layer,
  77. batch_size=2,
  78. src_seq_length=20,
  79. tgt_seq_length=10,
  80. hidden_size=64,
  81. num_heads=8,
  82. ffn_hidden_size=64,
  83. parallel_config=parallel_config)
  84. self.head = Linear(in_channels=64, out_channels=200)
  85. self.loss = CrossEntropyLoss(parallel_config=config.dp_mp_config)
  86. def construct(self, x1, x2, x3, x4, x5, y, mask):
  87. predict, _, _ = self.network(x1, x2, x3, x4, x5)
  88. predict = P.Reshape()(predict, (-1, F.shape(predict)[-1]))
  89. return self.loss(predict, y, mask)
  90. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  91. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  92. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  93. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  94. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  95. seq = 20
  96. if d_layer > 0:
  97. seq = 10
  98. label = Tensor(np.ones((2 * seq,)), mstype.int32)
  99. input_mask = Tensor(np.ones((2 * seq,)), mstype.float32)
  100. net = Net(en_layer=e_layer, de_layer=d_layer, parallel_config=arg_parallel_config)
  101. params = net.trainable_params()
  102. optimizer = AdamWeightDecay(params)
  103. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  104. memory_mask, label, input_mask)
  105. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  106. model = Model(net_with_grad)
  107. model.train(1, dataset, dataset_sink_mode=False)
  108. def test_transformer_model():
  109. set_auto_parallel_context(device_num=8, global_rank=0,
  110. full_batch=True,
  111. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  112. net = Transformer(encoder_layers=1,
  113. decoder_layers=2,
  114. batch_size=2,
  115. src_seq_length=20,
  116. tgt_seq_length=10,
  117. hidden_size=64,
  118. num_heads=8,
  119. ffn_hidden_size=64,
  120. parallel_config=config)
  121. encoder_input_value = Tensor(np.ones((2, 20, 64)), mstype.float32)
  122. encoder_input_mask = Tensor(np.ones((2, 20, 20)), mstype.float16)
  123. decoder_input_value = Tensor(np.ones((2, 10, 64)), mstype.float32)
  124. decoder_input_mask = Tensor(np.ones((2, 10, 10)), mstype.float16)
  125. memory_mask = Tensor(np.ones((2, 10, 20)), mstype.float16)
  126. net = NetWithLossFiveInputs(net)
  127. params = net.trainable_params()
  128. optimizer = AdamWeightDecay(params)
  129. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  130. memory_mask)
  131. net_with_grad = TrainOneStepCell(net, optimizer=optimizer)
  132. model = Model(net_with_grad)
  133. model.train(1, dataset, dataset_sink_mode=False)
  134. def test_transformer_model_head_parallel_only_encoder():
  135. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  136. run_total_transformer_model_head(e_layer=2, d_layer=0, arg_parallel_config=local_config)
  137. def test_transformer_model_head_parallel():
  138. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  139. run_total_transformer_model_head(e_layer=1, d_layer=1, arg_parallel_config=local_config)
  140. def test_transformer_model_head_parallel_decoder():
  141. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8)
  142. with pytest.raises(ValueError):
  143. run_total_transformer_model_head(e_layer=0, d_layer=1, arg_parallel_config=local_config)
  144. def test_transformer_model_head_stand_alone():
  145. local_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=1)
  146. run_total_transformer_model_head(e_layer=2, d_layer=2, arg_parallel_config=local_config)
  147. def test_pipeline_single_transformer():
  148. set_auto_parallel_context(device_num=32,
  149. full_batch=True,
  150. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
  151. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  152. net = Transformer(batch_size=4 // pipeline_config.micro_batch_num,
  153. src_seq_length=20,
  154. tgt_seq_length=10,
  155. encoder_layers=2,
  156. decoder_layers=2,
  157. hidden_size=64,
  158. num_heads=8,
  159. ffn_hidden_size=64,
  160. parallel_config=pipeline_config)
  161. encoder_input_value = Tensor(np.ones((4, 20, 64)), mstype.float32)
  162. encoder_input_mask = Tensor(np.ones((4, 20, 20)), mstype.float16)
  163. decoder_input_value = Tensor(np.ones((4, 10, 64)), mstype.float32)
  164. decoder_input_mask = Tensor(np.ones((4, 10, 10)), mstype.float16)
  165. memory_mask = Tensor(np.ones((4, 10, 20)), mstype.float16)
  166. net = NetWithLossFiveInputs(net)
  167. net = PipelineCell(net, pipeline_config.micro_batch_num)
  168. net = _VirtualDatasetCell(net)
  169. params = net.infer_param_pipeline_stage()
  170. optimizer = AdamWeightDecay(params)
  171. dataset = Dataset(encoder_input_value, encoder_input_mask, decoder_input_value, decoder_input_mask,
  172. memory_mask)
  173. update_cell = DynamicLossScaleUpdateCell(loss_scale_value=1024, scale_factor=2, scale_window=1000)
  174. net_with_grad = _TrainPipelineWithLossScaleCell(net, optimizer=optimizer,
  175. scale_sense=update_cell)
  176. model = Model(net_with_grad)
  177. model.train(1, dataset, dataset_sink_mode=False)
  178. def test_transformer_wrong_head():
  179. set_auto_parallel_context(device_num=32,
  180. full_batch=True,
  181. pipeline_stages=pipeline_config.pipeline_stage, global_rank=0,
  182. parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  183. error_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=8, vocab_emb_dp=False)
  184. with pytest.raises(ValueError):
  185. net = Transformer(batch_size=4,
  186. src_seq_length=20,
  187. tgt_seq_length=10,
  188. encoder_layers=2,
  189. decoder_layers=2,
  190. hidden_size=64,
  191. num_heads=7,
  192. ffn_hidden_size=64,
  193. parallel_config=error_test_config)
  194. with pytest.raises(ValueError):
  195. net = Transformer(batch_size=4,
  196. src_seq_length=20,
  197. tgt_seq_length=10,
  198. encoder_layers=2,
  199. decoder_layers=2,
  200. hidden_size=63,
  201. num_heads=7,
  202. ffn_hidden_size=64,
  203. parallel_config=error_test_config)
  204. del net
  205. def test_encoder():
  206. class NetWithLoss(nn.Cell):
  207. def __init__(self, network):
  208. super(NetWithLoss, self).__init__()
  209. self.loss = VirtualLoss()
  210. self.network = network
  211. def construct(self, x1, x2):
  212. predict, _ = self.network(x1, x2)
  213. return self.loss(predict)
  214. set_auto_parallel_context(device_num=8,
  215. full_batch=True,
  216. global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  217. net = TransformerEncoder(num_layers=2,
  218. batch_size=2,
  219. seq_length=16,
  220. hidden_size=8,
  221. ffn_hidden_size=64,
  222. num_heads=8,
  223. parallel_config=config)
  224. encoder_input_value = Tensor(np.ones((2, 16, 8)), mstype.float32)
  225. encoder_input_mask = Tensor(np.ones((2, 16, 16)), mstype.float16)
  226. net = NetWithLoss(net)
  227. dataset = Dataset(encoder_input_value, encoder_input_mask)
  228. model = Model(net)
  229. model.train(1, dataset, dataset_sink_mode=False)
  230. def test_decoder():
  231. class NetWithLoss(nn.Cell):
  232. def __init__(self, network):
  233. super(NetWithLoss, self).__init__()
  234. self.loss = VirtualLoss()
  235. self.network = network
  236. def construct(self, x1, x2, x3, x4):
  237. predict, _, _ = self.network(x1, x2, x3, x4)
  238. return self.loss(predict)
  239. set_auto_parallel_context(device_num=8,
  240. full_batch=True,
  241. global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  242. net = TransformerDecoder(num_layers=1,
  243. batch_size=8,
  244. hidden_size=16,
  245. ffn_hidden_size=8,
  246. num_heads=8,
  247. src_seq_length=20,
  248. tgt_seq_length=10,
  249. parallel_config=config)
  250. encoder_input_value = Tensor(np.ones((8, 20, 16)), mstype.float32)
  251. decoder_input_value = Tensor(np.ones((8, 10, 16)), mstype.float32)
  252. decoder_input_mask = Tensor(np.ones((8, 10, 10)), mstype.float16)
  253. memory_mask = Tensor(np.ones((8, 10, 20)), mstype.float16)
  254. net = NetWithLoss(net)
  255. dataset = Dataset(decoder_input_value, decoder_input_mask, encoder_input_value, memory_mask)
  256. model = Model(net)
  257. model.train(1, dataset, dataset_sink_mode=False)
  258. def test_vocabembedding_dp_true():
  259. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  260. class NetWithLoss(nn.Cell):
  261. def __init__(self, network):
  262. super(NetWithLoss, self).__init__()
  263. self.loss = VirtualLoss()
  264. self.network = network
  265. def construct(self, x1):
  266. predict, _ = self.network(x1)
  267. return self.loss(predict)
  268. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  269. net = NetWithLoss(net)
  270. encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
  271. dataset = Dataset(encoder_input_value)
  272. model = Model(net)
  273. model.train(1, dataset, dataset_sink_mode=False)
  274. def test_vocabembedding_dp_false():
  275. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  276. class NetWithLoss(nn.Cell):
  277. def __init__(self, network):
  278. super(NetWithLoss, self).__init__()
  279. self.loss = VirtualLoss()
  280. self.network = network
  281. def construct(self, x1):
  282. predict, _ = self.network(x1)
  283. return self.loss(predict)
  284. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  285. net = NetWithLoss(net)
  286. encoder_input_value = Tensor(np.ones((2, 64)), mstype.int32)
  287. dataset = Dataset(encoder_input_value)
  288. model = Model(net)
  289. model.train(1, dataset, dataset_sink_mode=False)
  290. def test_parallel_cross_entroy_loss_semi_auto_parallel():
  291. set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL)
  292. class NetWithLoss(nn.Cell):
  293. def __init__(self, network, config_setting):
  294. super(NetWithLoss, self).__init__()
  295. self.loss = CrossEntropyLoss(config_setting)
  296. self.network = network
  297. def construct(self, x1, x2, x3):
  298. predict, _ = self.network(x1)
  299. predict = P.Reshape()(predict, (-1, 16))
  300. return self.loss(predict, x2, x3)
  301. net = VocabEmbedding(vocab_size=160, embedding_size=16, parallel_config=config.embedding_dp_mp_config)
  302. net = NetWithLoss(net, config.dp_mp_config)
  303. embed_ids = Tensor(np.ones((2, 64)), mstype.int32)
  304. labels = Tensor(np.ones((2 * 64,)), mstype.int32)
  305. input_mask = Tensor(np.ones((2 * 64,)), mstype.float32)
  306. dataset = Dataset(embed_ids, labels, input_mask)
  307. model = Model(net)
  308. model.train(1, dataset, dataset_sink_mode=False)
  309. def test_transformer_parallel_config():
  310. parallel_test_config = TransformerOpParallelConfig(data_parallel=1, model_parallel=3)
  311. with pytest.raises(TypeError):
  312. parallel_test_config.data_parallel = False
  313. with pytest.raises(ValueError):
  314. parallel_test_config.data_parallel = 0
  315. with pytest.raises(TypeError):
  316. parallel_test_config.model_parallel = False
  317. with pytest.raises(ValueError):
  318. parallel_test_config.model_parallel = 0
  319. with pytest.raises(TypeError):
  320. parallel_test_config.pipeline_stage = False
  321. with pytest.raises(ValueError):
  322. parallel_test_config.pipeline_stage = 0
  323. with pytest.raises(TypeError):
  324. parallel_test_config.micro_batch_num = False
  325. with pytest.raises(ValueError):
  326. parallel_test_config.micro_batch_num = 0
  327. with pytest.raises(TypeError):
  328. parallel_test_config.gradient_aggregation_group = False
  329. with pytest.raises(ValueError):
  330. parallel_test_config.gradient_aggregation_group = 0
  331. with pytest.raises(TypeError):
  332. parallel_test_config.recompute = 1
  333. parallel_test_config.recompute = False
  334. assert not parallel_test_config.recompute
  335. def test_parallel_config():
  336. parallel_test_config = OpParallelConfig(data_parallel=1, model_parallel=3)
  337. with pytest.raises(ValueError):
  338. parallel_test_config.data_parallel = 0
  339. with pytest.raises(TypeError):
  340. parallel_test_config.model_parallel = False
  341. with pytest.raises(ValueError):
  342. parallel_test_config.model_parallel = 0
  343. assert parallel_test_config.model_parallel == 3
  344. def test_embedding_parallel_config():
  345. parallel_test_config = EmbeddingOpParallelConfig(data_parallel=1, model_parallel=3, vocab_emb_dp=False)
  346. with pytest.raises(ValueError):
  347. parallel_test_config.data_parallel = 0
  348. with pytest.raises(TypeError):
  349. parallel_test_config.model_parallel = False
  350. with pytest.raises(ValueError):
  351. parallel_test_config.model_parallel = 0
  352. with pytest.raises(TypeError):
  353. parallel_test_config.vocab_emb_dp = 0
  354. assert not parallel_test_config.vocab_emb_dp